diff --git a/build.ps1 b/build.ps1
new file mode 100644
index 000000000..0519ecba6
--- /dev/null
+++ b/build.ps1
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/build.sh b/build.sh
new file mode 100644
index 000000000..0519ecba6
--- /dev/null
+++ b/build.sh
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/cmd/assistant.go b/cmd/assistant.go
new file mode 100644
index 000000000..5742b8641
--- /dev/null
+++ b/cmd/assistant.go
@@ -0,0 +1,62 @@
+package cmd
+
+import (
+	"github.com/joho/godotenv"
+	"github.com/spf13/cobra"
+)
+
+var assistantCmd = &cobra.Command{
+	Use:   "assistant",
+	Short: "DNA CLI assistant for database design and normalization",
+	Long: `Database Normalization Assistant (DNA) helps with database design and normalization.
+It provides interactive guidance for schema design, normalization, and best practices.`,
+}
+
+func init() {
+	// Load .env file if it exists; a missing file is not an error.
+	_ = godotenv.Load()
+
+	rootCmd.AddCommand(assistantCmd)
+
+	// Add subcommands
+	assistantCmd.AddCommand(newAssistantChatCmd())
+	assistantCmd.AddCommand(newAssistantDoctorCmd())
+	assistantCmd.AddCommand(newAssistantSchemaCmd())
+}
+
+// Environment variables
+const (
+	EnvDNAProvider    = "DNA_PROVIDER"
+	EnvDNAAPIKey      = "DNA_API_KEY"
+	EnvDNAModel       = "DNA_MODEL"
+	EnvDNATemperature = "DNA_TEMPERATURE"
+
+	// Postgres connection settings
+	EnvPGHost     = "PGHOST"
+	EnvPGPort     = "PGPORT"
+	EnvPGDatabase = "PGDATABASE"
+	EnvPGUser     = "PGUSER"
+	EnvPGPassword = "PGPASSWORD"
+)
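+
+// Default values. Resolution order is sketched in internal/dna/config:
+// environment variables override dna.config.toml, which overrides these fallbacks.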
Empty for root help.", + }, + }, + }, + }, +} + +func newAssistantSchemaCmd() *cobra.Command { + return &cobra.Command{ + Use: "schema", + Short: "Get schema information about your database", + RunE: func(cmd *cobra.Command, args []string) error { + return nil // TODO: Implement schema command + }, + } +} + +func newAssistantCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "assistant", + Short: "DNA CLI assistant for database design and normalization", + } + + cmd.AddCommand( + newAssistantChatCmd(), + // Migration commands removed - these belong in main CLI + ) + + return cmd +} diff --git a/cmd/assistant_chat.go b/cmd/assistant_chat.go new file mode 100644 index 000000000..ccf601981 --- /dev/null +++ b/cmd/assistant_chat.go @@ -0,0 +1,464 @@ +package cmd + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + "github.com/fatih/color" + "github.com/joho/godotenv" + "github.com/sashabaranov/go-openai" + "github.com/spf13/cobra" +) + +type ChatSession struct { + client *openai.Client + history []openai.ChatCompletionMessage + isActive bool +} + +func newAssistantChatCmd() *cobra.Command { + return &cobra.Command{ + Use: "chat", + Short: "Start an interactive chat session with the DNA assistant", + RunE: func(cmd *cobra.Command, args []string) error { + return startChatSession() + }, + } +} + +func init() { + // Load environment variables from .env file + if err := godotenv.Load(); err != nil { + fmt.Println("Warning: .env file not found, using default values") + } + + // Initialize colors and other settings + assistantColor = color.New(color.FgGreen) + toolColor = color.New(color.FgYellow) + separator = strings.Repeat("=", 50) + toolSeparator = strings.Repeat("-", 30) + + // Check for required API key + if os.Getenv("DNA_API_KEY") == "" { + fmt.Println("DNA_API_KEY not set. Please set your API key: export DNA_API_KEY=your_api_key") + os.Exit(1) + } +} + +func startChatSession() error { + apiKey := os.Getenv(EnvDNAAPIKey) + if apiKey == "" { + return fmt.Errorf("DNA_API_KEY not set. Please set your API key: export DNA_API_KEY=your_api_key") + } + + session := &ChatSession{ + client: openai.NewClient(apiKey), + history: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: `You are a Database Normalization Assistant (DNA) for Supabase. +You help developers design and normalize their database schemas. +You provide guidance on: +1. Database design best practices +2. Normalization (1NF, 2NF, 3NF) +3. Supabase-specific features and optimizations +4. 
Schema analysis and improvements + +Always explain your reasoning and provide examples when relevant.`, + }, + }, + isActive: true, + } + + printWelcomeMessage() + return session.chatLoop() +} + +func printWelcomeMessage() { + fmt.Fprint(color.Output, helpText) +} + +func (s *ChatSession) chatLoop() error { + reader := bufio.NewReader(os.Stdin) + + for s.isActive { + fmt.Fprint(color.Output, prompt) + input, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("error reading input: %w", err) + } + + input = strings.TrimSpace(input) + if input == "" { + continue + } + + // First check hard-coded commands + switch input { + case "exit", "quit": + handleChatResponse("exit") + s.isActive = false + continue + case "help": + handleChatResponse("help") + continue + case "clear": + s.history = s.history[:1] + fmt.Fprintln(color.Output, bannerColor.Sprint("Chat history cleared.")) + continue + } + + // Then check for tool-specific commands + if strings.HasPrefix(input, "analyze") || + strings.HasPrefix(input, "verify") || + strings.HasPrefix(input, "search") { + if err := s.handleToolCommand(input); err != nil { + fmt.Fprintf(color.Output, "%s%v%s\n", + separator, + errorColor.Sprint(err), + separator) + } + continue + } + + // Finally, treat as chat message + if err := s.handleMessage(input); err != nil { + fmt.Fprintf(color.Output, "%s%v%s\n", + separator, + errorColor.Sprint(err), + separator) + } + } + + return nil +} + +func (s *ChatSession) handleToolCommand(input string) error { + // Parse command and arguments + parts := strings.Fields(input) + if len(parts) < 2 { + return fmt.Errorf("invalid command format") + } + + cmd := parts[0] + args := parts[1:] + + // Handle direct tool calls + switch cmd { + case "analyze": + return s.handleDirectAnalyze(args) + case "verify": + return s.handleDirectVerify(args) + case "search": + return s.handleDirectSearch(args) + default: + return fmt.Errorf("unknown command: %s", cmd) + } +} + +func (s *ChatSession) handleMessage(input string) error { + // Add user message to history + s.history = append(s.history, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: input, + }) + + // Keep getting responses until we get a normal message + for { + resp, err := s.client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT4Turbo1106, + Messages: s.history, + Temperature: defaultConfig.Temperature, + Tools: []openai.Tool{ + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "search_supabase_docs", + Description: "Search Supabase documentation for relevant information", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "The search query for the documentation", + }, + "topic": map[string]interface{}{ + "type": "string", + "enum": []string{"database", "auth", "storage", "edge-functions", "realtime"}, + "description": "The specific topic to search within", + }, + }, + "required": []string{"query"}, + }, + }, + }, + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "analyze_schema", + Description: "Analyze database schema including tables, columns, and security policies", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "table": map[string]interface{}{ + "type": "string", + "description": "The specific table to analyze, or 'all' to list tables", + }, + }, + "required": 
[]string{"table"}, + }, + }, + }, + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "analyze_functions", + Description: "Analyze stored procedures and functions in the database", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "The specific function to analyze, or 'all' to list functions", + }, + "schema": map[string]interface{}{ + "type": "string", + "enum": []string{"public", "auth", "storage", "graphql", "graphql_public"}, + "description": "The schema to search in", + "default": "public", + }, + }, + "required": []string{"name"}, + }, + }, + }, + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "get_cli_help", + Description: "Get help information for DNA CLI commands", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "command": map[string]interface{}{ + "type": "string", + "description": "The command to get help for (e.g., 'db', 'assistant'). Empty for root help.", + }, + }, + "required": []string{"command"}, + }, + }, + }, + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "create_migration", + Description: "Create a new empty migration file", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + "description": "Name for the migration (will be prefixed with timestamp)", + }, + }, + "required": []string{"name"}, + }, + }, + }, + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "write_migration", + Description: "Write SQL content to a migration file", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "version": map[string]interface{}{ + "type": "string", + "description": "Migration version/timestamp", + }, + "sql": map[string]interface{}{ + "type": "string", + "description": "SQL content to write to the migration file", + }, + }, + "required": []string{"version", "sql"}, + }, + }, + }, + { + Type: "function", + Function: &openai.FunctionDefinition{ + Name: "apply_migrations", + Description: "Apply pending migrations to the database", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "include_all": map[string]interface{}{ + "type": "boolean", + "description": "Include all migrations not found on remote history table", + "default": false, + }, + }, + }, + }, + }, + }, + }, + ) + if err != nil { + // Reset to initial state on error + s.history = []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: `You are a Database Normalization Assistant (DNA) for Supabase. +You help developers design and normalize their database schemas. +You provide guidance on: +1. Database design best practices +2. Normalization (1NF, 2NF, 3NF) +3. Supabase-specific features and optimizations +4. Schema analysis and improvements + +Always explain your reasoning and provide examples when relevant.`, + }, + } + fmt.Println("Chat history has been reset due to an error.") + return fmt.Errorf("error getting response: %w", err) + } + + // If response has no content, reset and return error + if resp.Choices[0].Message.Content == "" && resp.Choices[0].Message.ToolCalls == nil { + s.history = []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: `You are a Database Normalization Assistant (DNA) for Supabase. 
+You help developers design and normalize their database schemas. +You provide guidance on: +1. Database design best practices +2. Normalization (1NF, 2NF, 3NF) +3. Supabase-specific features and optimizations +4. Schema analysis and improvements + +Always explain your reasoning and provide examples when relevant.`, + }, + } + fmt.Println("Chat history has been reset due to empty response.") + return fmt.Errorf("received empty response from assistant") + } + + // If no tool calls, we're done + if resp.Choices[0].Message.ToolCalls == nil { + s.history = append(s.history, resp.Choices[0].Message) + fmt.Fprintln(color.Output, separator) + assistantColor.Fprintln(color.Output, resp.Choices[0].Message.Content) + fmt.Fprintln(color.Output, separator) + return nil + } + + // Handle tool calls + for _, toolCall := range resp.Choices[0].Message.ToolCalls { + fmt.Fprintln(color.Output, toolSeparator) + toolColor.Fprintf(color.Output, "Using tool: %s\n", toolCall.Function.Name) + + result, err := handleToolCall(&toolCall.Function) + if err != nil { + return fmt.Errorf("error handling tool call: %w", err) + } + + toolColor.Fprintln(color.Output, result.Result) + fmt.Fprintln(color.Output, toolSeparator) + + // Check for null content before adding to history + toolCallMessage := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: "Using tool...", + ToolCalls: []openai.ToolCall{toolCall}, + } + toolResponseMessage := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: result.Result, + Name: toolCall.Function.Name, + ToolCallID: toolCall.ID, + } + + // Only append if content is not empty + if toolCallMessage.Content != "" { + s.history = append(s.history, toolCallMessage) + } + if toolResponseMessage.Content != "" { + s.history = append(s.history, toolResponseMessage) + } + } + // Loop continues to get next response + } +} + +// Direct tool handlers +func (s *ChatSession) handleDirectAnalyze(args []string) error { + if len(args) == 0 { + return fmt.Errorf("table name required") + } + result, err := handleToolCall(&openai.FunctionCall{ + Name: "analyze_schema", + Arguments: fmt.Sprintf(`{"table":"%s"}`, args[0]), + }) + if err != nil { + return err + } + toolColor.Fprintln(color.Output, result.Result) + return nil +} + +func (s *ChatSession) handleDirectVerify(args []string) error { + if len(args) == 0 { + return fmt.Errorf("file name required") + } + // TODO: Implement verify logic + return fmt.Errorf("verify not implemented yet") +} + +func (s *ChatSession) handleDirectSearch(args []string) error { + if len(args) == 0 { + return fmt.Errorf("search query required") + } + result, err := handleToolCall(&openai.FunctionCall{ + Name: "search_supabase_docs", + Arguments: fmt.Sprintf(`{"query":"%s"}`, args[0]), + }) + if err != nil { + return err + } + toolColor.Fprintln(color.Output, result.Result) + return nil +} + +// Add a new command to reset the chat +func (s *ChatSession) handleCommand(cmd string, args []string) error { + switch cmd { + case "reset": + s.history = []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: `You are a Database Normalization Assistant (DNA) for Supabase. +You help developers design and normalize their database schemas. +You provide guidance on: +1. Database design best practices +2. Normalization (1NF, 2NF, 3NF) +3. Supabase-specific features and optimizations +4. 
Schema analysis and improvements
+
+Always explain your reasoning and provide examples when relevant.`,
+			},
+		}
+		fmt.Println("Chat history has been reset.")
+		return nil
+		// ... other commands ...
+	}
+	return nil
+}
diff --git a/cmd/assistant_config.go b/cmd/assistant_config.go
new file mode 100644
index 000000000..da1748836
--- /dev/null
+++ b/cmd/assistant_config.go
@@ -0,0 +1,101 @@
+package cmd
+
+import (
+	"fmt"
+
+	"github.com/spf13/afero"
+	"github.com/spf13/cobra"
+	"github.com/supabase/cli/internal/dna/config"
+)
+
+func newAssistantConfigCmd() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:   "config",
+		Short: "Manage DNA assistant configuration",
+	}
+
+	cmd.AddCommand(newAssistantConfigSetCmd())
+	cmd.AddCommand(newAssistantConfigGetCmd())
+
+	return cmd
+}
+
+func newAssistantConfigSetCmd() *cobra.Command {
+	var apiKey, provider, model string
+	var temperature float32
+
+	cmd := &cobra.Command{
+		Use:   "set",
+		Short: "Set DNA assistant configuration values",
+		RunE: func(cmd *cobra.Command, args []string) error {
+			fs := afero.NewOsFs()
+			cfg, err := config.Load(fs)
+			if err != nil {
+				// If config doesn't exist, start from a copy of the defaults.
+				// Copying avoids mutating the shared config.DefaultConfig value.
+				defaults := config.DefaultConfig
+				cfg = &defaults
+			}
+
+			// Update only provided values
+			if apiKey != "" {
+				cfg.APIKey = apiKey
+			}
+			if provider != "" {
+				cfg.Provider = provider
+			}
+			if model != "" {
+				cfg.Model = model
+			}
+			if cmd.Flags().Changed("temperature") {
+				cfg.Temperature = temperature
+			}
+
+			if err := cfg.Save(fs); err != nil {
+				return fmt.Errorf("error saving config: %w", err)
+			}
+
+			fmt.Println("Configuration updated successfully")
+			return nil
+		},
+	}
+
+	cmd.Flags().StringVar(&apiKey, "api-key", "", "OpenAI API key")
+	cmd.Flags().StringVar(&provider, "provider", "", "AI provider (openai)")
+	cmd.Flags().StringVar(&model, "model", "", "AI model (e.g., gpt-4)")
+	cmd.Flags().Float32Var(&temperature, "temperature", 0.7, "Model temperature (0.0-1.0)")
+
+	return cmd
+}
+
+func newAssistantConfigGetCmd() *cobra.Command {
+	return &cobra.Command{
+		Use:   "get",
+		Short: "Display current DNA assistant configuration",
+		RunE: func(cmd *cobra.Command, args []string) error {
+			cfg, err := config.Load(afero.NewOsFs())
+			if err != nil {
+				return fmt.Errorf("error loading config: %w", err)
+			}
+
+			fmt.Printf("Current configuration:\n")
+			fmt.Printf("API Key: %s\n", maskAPIKey(cfg.APIKey))
+			fmt.Printf("Provider: %s\n", cfg.Provider)
+			fmt.Printf("Model: %s\n", cfg.Model)
+			fmt.Printf("Temperature: %.2f\n", cfg.Temperature)
+
+			return nil
+		},
+	}
+}
+
+func maskAPIKey(key string) string {
+	if len(key) <= 8 {
+		return "********"
+	}
+	return key[:4] + "..." 
+ key[len(key)-4:] +} + +func init() { + assistantCmd.AddCommand(newAssistantConfigCmd()) +} diff --git a/cmd/assistant_doctor.go b/cmd/assistant_doctor.go new file mode 100644 index 000000000..5fd3f92db --- /dev/null +++ b/cmd/assistant_doctor.go @@ -0,0 +1,143 @@ +package cmd + +import ( + "context" + "fmt" + "os" + + "github.com/sashabaranov/go-openai" + "github.com/spf13/cobra" +) + +type HealthCheck struct { + Name string + Check func() (string, error) + FixMessage string +} + +func newAssistantDoctorCmd() *cobra.Command { + return &cobra.Command{ + Use: "doctor", + Short: "Check environment setup for the DNA assistant", + RunE: func(cmd *cobra.Command, args []string) error { + return runEnvironmentChecks() + }, + } +} + +func runEnvironmentChecks() error { + checks := []HealthCheck{ + { + Name: "Environment Variables", + Check: func() (string, error) { + apiKey := os.Getenv(EnvDNAAPIKey) + if apiKey == "" { + return "", fmt.Errorf("DNA_API_KEY not set") + } + + // Validate API key by making a test request + client := openai.NewClient(apiKey) + _, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: "gpt-3.5-turbo", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Test connection", + }, + }, + }, + ) + if err != nil { + return "", fmt.Errorf("invalid API key: %v", err) + } + + provider := os.Getenv(EnvDNAProvider) + if provider == "" { + return "Using default provider: openai", nil + } + return fmt.Sprintf("Using provider: %s", provider), nil + }, + FixMessage: `Export the required environment variables: +export DNA_API_KEY=your_api_key +export DNA_PROVIDER=openai # Optional, defaults to openai`, + }, + { + Name: "Supabase Project", + Check: func() (string, error) { + if !isSupabaseProject() { + return "", fmt.Errorf("not in a Supabase project directory") + } + return "Supabase project detected", nil + }, + FixMessage: `Run this command from within a Supabase project directory. +Or initialize a new project with: supabase init`, + }, + { + Name: "Database Connection", + Check: func() (string, error) { + config, err := loadDatabaseConfig() + if err != nil { + return "", fmt.Errorf("failed to load database config: %w", err) + } + if err := checkDatabaseConnection(config); err != nil { + return "", fmt.Errorf("database connection failed: %w", err) + } + return "Database connection successful", nil + }, + FixMessage: `Ensure your database is running: supabase start +Check your database connection settings in supabase/config.toml`, + }, + } + + // Run all checks + hasErrors := false + for _, check := range checks { + fmt.Printf("Checking %s... 
", check.Name) + msg, err := check.Check() + if err != nil { + hasErrors = true + fmt.Println("āŒ") + fmt.Printf("Error: %v\n", err) + if check.FixMessage != "" { + fmt.Printf("Fix:\n%s\n", check.FixMessage) + } + } else { + fmt.Println("āœ…") + if msg != "" { + fmt.Printf("Info: %s\n", msg) + } + } + fmt.Println() + } + + if hasErrors { + return fmt.Errorf("one or more checks failed") + } + return nil +} + +func isSupabaseProject() bool { + _, err := os.Stat("supabase") + return err == nil +} + +func loadDatabaseConfig() (map[string]string, error) { + // This is a placeholder - in the real implementation, + // we would parse the supabase/config.toml file + config := map[string]string{ + "host": "localhost", + "port": "54322", + "user": "postgres", + "password": "postgres", + "database": "postgres", + } + return config, nil +} + +func checkDatabaseConnection(config map[string]string) error { + // This is a placeholder - in the real implementation, + // we would attempt to connect to the database + return nil +} diff --git a/cmd/assistant_schema.go b/cmd/assistant_schema.go new file mode 100644 index 000000000..88a6deb27 --- /dev/null +++ b/cmd/assistant_schema.go @@ -0,0 +1,338 @@ +package cmd + +import ( + "database/sql" + "fmt" + "os" + + _ "github.com/lib/pq" // Postgres driver +) + +// analyzeSchema handles the main schema analysis +func analyzeSchema(table string) (*ToolResponse, error) { + if table == "all" { + tables, err := listTables() + if err != nil { + return &ToolResponse{ + Success: false, + Result: fmt.Sprintf("Error listing tables: %v", err), + }, nil + } + return &ToolResponse{ + Success: true, + Result: fmt.Sprintf("Found tables:\n%s", tables), + }, nil + } + + schema, err := getTableSchema(table) + if err != nil { + return &ToolResponse{ + Success: false, + Result: fmt.Sprintf("Error getting schema for %s: %v", table, err), + }, nil + } + + rls, err := getTableRLS(table) + if err != nil { + return &ToolResponse{ + Success: false, + Result: fmt.Sprintf("Error getting RLS for %s: %v", table, err), + }, nil + } + + return &ToolResponse{ + Success: true, + Result: fmt.Sprintf("Table: %s\n\nSchema:\n%s\n\nRow Level Security:\n%s", + table, schema, rls), + }, nil +} + +func listTables() (string, error) { + db, err := getDB() + if err != nil { + return "", fmt.Errorf("error connecting to database: %w", err) + } + defer db.Close() + + rows, err := db.Query(` + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' + ORDER BY table_name; + `) + if err != nil { + return "", fmt.Errorf("error querying tables: %w", err) + } + defer rows.Close() + + var tables string + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return "", fmt.Errorf("error scanning row: %w", err) + } + tables += fmt.Sprintf("Found table: %s\n", tableName) + } + + if tables == "" { + return "No tables found.", nil + } + + return tables, nil +} + +func getTableSchema(table string) (string, error) { + db, err := getDB() + if err != nil { + return "", fmt.Errorf("error connecting to database: %w", err) + } + defer db.Close() + + rows, err := db.Query(` + SELECT + column_name, + data_type, + is_nullable, + COALESCE(column_default, 'NULL') as column_default, + CASE + WHEN is_identity = 'YES' THEN 'IDENTITY' + WHEN is_generated = 'ALWAYS' THEN 'GENERATED' + ELSE '' + END as special + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = $1 + ORDER BY ordinal_position; + `, table) + if err != nil { + return "", 
fmt.Errorf("error querying schema: %w", err) + } + defer rows.Close() + + var schema string + for rows.Next() { + var colName, dataType, nullable, defaultVal, special string + if err := rows.Scan(&colName, &dataType, &nullable, &defaultVal, &special); err != nil { + return "", fmt.Errorf("error scanning row: %w", err) + } + schema += fmt.Sprintf("Column: %s\nType: %s\nNullable: %s\nDefault: %s\nSpecial: %s\n\n", + colName, dataType, nullable, defaultVal, special) + } + + if schema == "" { + return "No columns found.", nil + } + + return schema, nil +} + +// getDB returns a database connection +func getDB() (*sql.DB, error) { + connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", + os.Getenv(EnvPGHost), + os.Getenv(EnvPGPort), + os.Getenv(EnvPGUser), + os.Getenv(EnvPGPassword), + os.Getenv(EnvPGDatabase)) + + return sql.Open("postgres", connStr) +} + +// Convert getTableRLS to use database/sql +func getTableRLS(table string) (string, error) { + db, err := getDB() + if err != nil { + return "", fmt.Errorf("error connecting to database: %w", err) + } + defer db.Close() + + rows, err := db.Query(` + SELECT + pol.polname as policy_name, + CASE WHEN pol.polpermissive THEN 'PERMISSIVE' ELSE 'RESTRICTIVE' END as policy_type, + CASE + WHEN pol.polroles = '{0}' THEN 'PUBLIC' + ELSE array_to_string(ARRAY( + SELECT rolname + FROM pg_roles + WHERE oid = ANY(pol.polroles) + ), ',') + END as roles, + pol.polcmd as operation, + pg_get_expr(pol.polqual, pol.polrelid) as using_expression, + pg_get_expr(pol.polwithcheck, pol.polrelid) as with_check_expression + FROM pg_policy pol + JOIN pg_class cls ON pol.polrelid = cls.oid + WHERE cls.relname = $1; + `, table) + if err != nil { + return "", fmt.Errorf("error querying RLS: %w", err) + } + defer rows.Close() + + var policies string + for rows.Next() { + var name, ptype, roles, op, using, check string + if err := rows.Scan(&name, &ptype, &roles, &op, &using, &check); err != nil { + return "", fmt.Errorf("error scanning row: %w", err) + } + policies += fmt.Sprintf("Policy: %s\nType: %s\nRoles: %s\nOperation: %s\nUsing: %s\nWith Check: %s\n\n", + name, ptype, roles, op, using, check) + } + + return policies, nil +} + +func analyzeFunctions(name string, schema string) (*ToolResponse, error) { + if name == "all" { + funcs, err := listFunctions(schema) + if err != nil { + return &ToolResponse{ + Success: false, + Result: fmt.Sprintf("Error listing functions: %v", err), + }, nil + } + return &ToolResponse{ + Success: true, + Result: fmt.Sprintf("Found functions in schema %s:\n%s", schema, funcs), + }, nil + } + + details, err := getFunctionDetails(name, schema) + if err != nil { + return &ToolResponse{ + Success: false, + Result: fmt.Sprintf("Error getting function details for %s: %v", name, err), + }, nil + } + + security, err := getFunctionSecurity(name, schema) + if err != nil { + return &ToolResponse{ + Success: false, + Result: fmt.Sprintf("Error getting function security for %s: %v", name, err), + }, nil + } + + return &ToolResponse{ + Success: true, + Result: fmt.Sprintf("Function: %s\n\nDefinition:\n%s\n\nSecurity:\n%s", + name, details, security), + }, nil +} + +func listFunctions(schema string) (string, error) { + db, err := getDB() + if err != nil { + return "", fmt.Errorf("error connecting to database: %w", err) + } + defer db.Close() + + rows, err := db.Query(` + SELECT + p.proname as name, + pg_get_function_arguments(p.oid) as arguments, + t.typname as return_type + FROM pg_proc p + JOIN pg_type t ON p.prorettype = 
+func listFunctions(schema string) (string, error) {
+	db, err := getDB()
+	if err != nil {
+		return "", fmt.Errorf("error connecting to database: %w", err)
+	}
+	defer db.Close()
+
+	rows, err := db.Query(`
+		SELECT
+			p.proname as name,
+			pg_get_function_arguments(p.oid) as arguments,
+			t.typname as return_type
+		FROM pg_proc p
+		JOIN pg_type t ON p.prorettype = t.oid
+		JOIN pg_namespace n ON p.pronamespace = n.oid
+		WHERE n.nspname = $1
+		ORDER BY p.proname;
+	`, schema)
+	if err != nil {
+		return "", fmt.Errorf("error querying functions: %w", err)
+	}
+	defer rows.Close()
+
+	var funcs string
+	for rows.Next() {
+		var name, args, retType string
+		if err := rows.Scan(&name, &args, &retType); err != nil {
+			return "", fmt.Errorf("error scanning row: %w", err)
+		}
+		funcs += fmt.Sprintf("%s(%s) RETURNS %s\n", name, args, retType)
+	}
+
+	return funcs, nil
+}
+
+func getFunctionDetails(name, schema string) (string, error) {
+	db, err := getDB()
+	if err != nil {
+		return "", fmt.Errorf("error connecting to database: %w", err)
+	}
+	defer db.Close()
+
+	rows, err := db.Query(`
+		SELECT
+			p.proname as name,
+			pg_get_functiondef(p.oid) as definition,
+			pg_get_function_arguments(p.oid) as arguments,
+			t.typname as return_type,
+			p.prosrc as source
+		FROM pg_proc p
+		JOIN pg_type t ON p.prorettype = t.oid
+		JOIN pg_namespace n ON p.pronamespace = n.oid
+		WHERE n.nspname = $1
+		AND p.proname = $2;
+	`, schema, name)
+	if err != nil {
+		return "", fmt.Errorf("error querying function details: %w", err)
+	}
+	defer rows.Close()
+
+	var details string
+	for rows.Next() {
+		var name, def, args, retType, src string
+		if err := rows.Scan(&name, &def, &args, &retType, &src); err != nil {
+			return "", fmt.Errorf("error scanning row: %w", err)
+		}
+		details += fmt.Sprintf("Name: %s\nArguments: %s\nReturns: %s\nDefinition:\n%s\nSource:\n%s\n",
+			name, args, retType, def, src)
+	}
+
+	return details, nil
+}
+
+func getFunctionSecurity(name, schema string) (string, error) {
+	db, err := getDB()
+	if err != nil {
+		return "", fmt.Errorf("error connecting to database: %w", err)
+	}
+	defer db.Close()
+
+	// proacl is an aclitem[] and cannot be cast to regrole[] directly;
+	// expand it with aclexplode to list the grantees.
+	rows, err := db.Query(`
+		SELECT
+			p.proname as name,
+			CASE
+				WHEN p.prosecdef THEN 'SECURITY DEFINER'
+				ELSE 'SECURITY INVOKER'
+			END as security,
+			CASE WHEN p.proleakproof THEN 'LEAKPROOF' ELSE 'NOT LEAKPROOF' END as leakproof,
+			array_to_string(ARRAY(
+				SELECT CASE WHEN acl.grantee = 0 THEN 'PUBLIC' ELSE acl.grantee::regrole::text END
+				FROM aclexplode(COALESCE(p.proacl, '{}'::aclitem[])) AS acl
+			), ',') as grantees
+		FROM pg_proc p
+		JOIN pg_namespace n ON p.pronamespace = n.oid
+		WHERE n.nspname = $1
+		AND p.proname = $2;
+	`, schema, name)
+	if err != nil {
+		return "", fmt.Errorf("error querying function security: %w", err)
+	}
+	defer rows.Close()
+
+	var security string
+	for rows.Next() {
+		var name, sec, leak, grantees string
+		if err := rows.Scan(&name, &sec, &leak, &grantees); err != nil {
+			return "", fmt.Errorf("error scanning row: %w", err)
+		}
+		security += fmt.Sprintf("Security: %s\nLeakproof: %s\nGrantees: %s\n",
+			sec, leak, grantees)
+	}
+
+	return security, nil
+}
diff --git a/cmd/assistant_search.go b/cmd/assistant_search.go
new file mode 100644
index 000000000..21a0e44a8
--- /dev/null
+++ b/cmd/assistant_search.go
@@ -0,0 +1,73 @@
+package cmd
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+)
+
+type SearchResult struct {
+	Title       string `json:"title"`
+	URL         string `json:"url"`
+	Description string `json:"description"`
+}
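+
+// The docs search endpoint used below is assumed rather than a documented
+// API; it is expected to return a JSON array shaped roughly like:
+//
+//	[{"title":"Row Level Security","url":"https://supabase.com/docs/guides/auth/row-level-security","description":"..."}]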
+
+func searchDocs(query string, topic string) (*ToolResponse, error) {
+	// Build search URL
+	baseURL := "https://supabase.com/docs/search"
+	params := url.Values{}
+	params.Add("q", query)
+	if topic != "" {
+		params.Add("topic", topic)
+	}
+
+	// Make request
+	resp, err := http.Get(baseURL + "?" + params.Encode())
+	if err != nil {
+		return &ToolResponse{
+			Success: false,
+			Result:  fmt.Sprintf("Error searching docs: %v", err),
+		}, nil
+	}
+	defer resp.Body.Close()
+
+	// Read response
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return &ToolResponse{
+			Success: false,
+			Result:  fmt.Sprintf("Error reading response: %v", err),
+		}, nil
+	}
+
+	// Format results nicely
+	var results []SearchResult
+	if err := json.Unmarshal(body, &results); err != nil {
+		return &ToolResponse{
+			Success: false,
+			Result:  fmt.Sprintf("Error parsing results: %v", err),
+		}, nil
+	}
+
+	// Build formatted response
+	var formattedResult string
+	for _, result := range results {
+		formattedResult += fmt.Sprintf("\nšŸ“š %s\nšŸ”— %s\nšŸ“ %s\n",
+			result.Title,
+			result.URL,
+			result.Description,
+		)
+	}
+
+	return &ToolResponse{
+		Success: true,
+		Result:  formattedResult,
+	}, nil
+}
diff --git a/cmd/assistant_test.go b/cmd/assistant_test.go
new file mode 100644
index 000000000..881561613
--- /dev/null
+++ b/cmd/assistant_test.go
@@ -0,0 +1,33 @@
+package cmd
+
+import (
+	"os"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestDNAAssistant(t *testing.T) {
+	// Skip if DNA_API_KEY is not set to avoid running API tests in CI
+	if os.Getenv("DNA_API_KEY") == "" {
+		t.Skip("DNA_API_KEY not set")
+	}
+
+	t.Run("doctor command", func(t *testing.T) {
+		err := runEnvironmentChecks()
+		assert.NoError(t, err, "doctor command should pass when environment is properly set up")
+	})
+}
+
+// TestDNAAssistantNoAuth tests behavior when auth is not configured
+func TestDNAAssistantNoAuth(t *testing.T) {
+	// Temporarily clear DNA_API_KEY
+	originalKey := os.Getenv("DNA_API_KEY")
+	os.Setenv("DNA_API_KEY", "")
+	defer os.Setenv("DNA_API_KEY", originalKey)
+
+	t.Run("doctor command without auth", func(t *testing.T) {
+		err := runEnvironmentChecks()
+		assert.Error(t, err, "doctor command should fail when DNA_API_KEY is not set")
+	})
+}
diff --git a/cmd/assistant_tools.go b/cmd/assistant_tools.go
new file mode 100644
index 000000000..6bccdb886
--- /dev/null
+++ b/cmd/assistant_tools.go
@@ -0,0 +1,716 @@
+package cmd
+
+import (
+	"database/sql"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strconv"
+	"strings"
+
+	_ "github.com/lib/pq"
+	"github.com/sashabaranov/go-openai"
+	"github.com/spf13/afero"
+	"github.com/supabase/cli/internal/migration/new"
+	"github.com/supabase/cli/internal/utils/flags"
+)
+
+var LineSep string
+
+func init() {
+	if runtime.GOOS == "windows" {
+		LineSep = "\r\n"
+	} else {
+		LineSep = "\n"
+	}
+}
+
+const (
+	String = "string"
+	Object = "object"
+	Bool   = "boolean"
+)
+
+const defaultMigrationsPath = "supabase/migrations"
+
+type Definition struct {
+	Type        string                `json:"type"`
+	Description string                `json:"description,omitempty"`
+	Properties  map[string]Definition `json:"properties,omitempty"`
+	Required    []string              `json:"required,omitempty"`
+	Enum        []string              `json:"enum,omitempty"`
+}
+
+type ToolResponse struct {
+	Result  string `json:"result"`
+	Success bool   `json:"success"`
+}
+
+// migrationTools defines the migration-related tool schemas; they are also
+// duplicated into the combined tools list below so everything can be passed
+// to the API as a single slice.
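+//
+// For reference, a create_migration tool call coming back from the model is
+// assumed to serialize roughly as (illustrative, not verbatim API output):
+//
+//	{"name": "create_migration", "arguments": "{\"name\":\"add_likes\",\"sql\":\"BEGIN; ... COMMIT;\"}"}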
+var migrationTools = []openai.FunctionDefinition{
+	{
+		Name:        "create_migration",
+		Description: "Create a new SQL migration file with a timestamped name in the supabase/migrations directory",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"name": {
+					Type:        String,
+					Description: "Descriptive name for the migration (e.g., 'add_user_table', 'update_post_constraints')",
+				},
+				"sql": {
+					Type:        String,
+					Description: "SQL statements to be executed in the migration (should include BEGIN and COMMIT)",
+				},
+			},
+			Required: []string{"name", "sql"},
+		},
+	},
+	{
+		Name:        "apply_migration",
+		Description: "Apply all pending migrations to the database in sequential order",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"target": {
+					Type:        String,
+					Enum:        []string{"local", "linked"},
+					Description: "Target database: 'local' for development or 'linked' for remote Supabase project",
+				},
+			},
+		},
+	},
+	{
+		Name:        "edit_migration",
+		Description: "Modify an existing migration file in the supabase/migrations directory",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"filename": {
+					Type:        String,
+					Description: "Full filename of the migration to edit (e.g., '20250220171330_add_likes.sql')",
+				},
+				"sql": {
+					Type:        String,
+					Description: "New SQL content to replace the existing migration (should include BEGIN and COMMIT)",
+				},
+			},
+			Required: []string{"filename", "sql"},
+		},
+	},
+}
+
+// Combine all tools
+var tools = []openai.FunctionDefinition{
+	// Existing tools
+	{
+		Name:        "search_supabase_docs",
+		Description: "Search Supabase documentation",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"query": {
+					Type:        String,
+					Description: "Query to search for",
+				},
+				"topic": {
+					Type:        String,
+					Description: "Topic to search within",
+				},
+			},
+			// topic is optional, matching the inline definition used by chat
+			Required: []string{"query"},
+		},
+	},
+	{
+		Name:        "analyze_schema",
+		Description: "Analyze database schema",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"table": {
+					Type:        String,
+					Description: "Name of the table to analyze",
+				},
+			},
+			Required: []string{"table"},
+		},
+	},
+	{
+		Name:        "analyze_functions",
+		Description: "Analyze stored procedures",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"name": {
+					Type:        String,
+					Description: "Name of the function",
+				},
+				"schema": {
+					Type:        String,
+					Description: "Schema of the function",
+				},
+			},
+			Required: []string{"name", "schema"},
+		},
+	},
+	{
+		Name:        "get_cli_help",
+		Description: "Get CLI help information",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"command": {
+					Type:        String,
+					Description: "Command to get help for",
+				},
+			},
+			Required: []string{"command"},
+		},
+	},
+	// Add migration tools
+	{
+		Name:        "create_migration",
+		Description: "Create a new SQL migration file with a timestamped name in the supabase/migrations directory",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"name": {
+					Type:        String,
+					Description: "Descriptive name for the migration (e.g., 'add_user_table', 'update_post_constraints')",
+				},
+				"sql": {
+					Type:        String,
+					Description: "SQL statements to be executed in the migration (should include BEGIN and COMMIT)",
+				},
+			},
+			Required: []string{"name", "sql"},
+		},
+	},
+	{
+		Name:        "apply_migration",
+		Description: "Apply all pending migrations to the database in sequential order",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"target": {
+					Type:        String,
+					Enum:        []string{"local", "linked"},
+					Description: "Target database: 'local' for development or 'linked' for remote Supabase project",
+				},
+			},
+		},
+	},
+	{
+		Name:        "edit_migration",
+		Description: "Modify an existing migration file in the supabase/migrations directory",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"filename": {
+					Type:        String,
+					Description: "Full filename of the migration to edit (e.g., '20250220171330_add_likes.sql')",
+				},
+				"sql": {
+					Type:        String,
+					Description: "New SQL content to replace the existing migration (should include BEGIN and COMMIT)",
+				},
+			},
+			Required: []string{"filename", "sql"},
+		},
+	},
+	{
+		Name:        "write_migration",
+		Description: "Write SQL content to a migration file",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"version": {
+					Type:        String,
+					Description: "IMPORTANT: Use the EXACT filename from create_migration, including both timestamp and name (e.g., if create_migration returns '20250221204247_enable_rls_and_restrict_access.sql', use that complete filename)",
+				},
+				"sql": {
+					Type:        String,
+					Description: "SQL content to write to the migration file (should include BEGIN and COMMIT)",
+				},
+			},
+			Required: []string{"version", "sql"},
+		},
+	},
+	{
+		Name:        "list_migrations",
+		Description: "List all migration files in the supabase/migrations directory",
+		Parameters: Definition{
+			Type:       Object,
+			Properties: map[string]Definition{},
+		},
+	},
+	{
+		Name:        "read_migration",
+		Description: "Read the contents of a specific migration file",
+		Parameters: Definition{
+			Type: Object,
+			Properties: map[string]Definition{
+				"version": {
+					Type:        String,
+					Description: "Version/timestamp of the migration file (e.g., '20250220171330')",
+				},
+			},
+			Required: []string{"version"},
+		},
+	},
+}
+
+func handleToolCall(call *openai.FunctionCall) (*ToolResponse, error) {
+	switch call.Name {
+	case "search_supabase_docs":
+		var params struct {
+			Query string `json:"query"`
+			Topic string `json:"topic"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+		return searchDocs(params.Query, params.Topic)
+
+	case "analyze_schema":
+		var params struct {
+			Table string `json:"table"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+		return analyzeSchema(params.Table)
+
+	case "analyze_functions":
+		var params struct {
+			Name   string `json:"name"`
+			Schema string `json:"schema"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+		return analyzeFunctions(params.Name, params.Schema)
+
+	case "create_migration":
+		var params struct {
+			Name string `json:"name"`
+			SQL  string `json:"sql"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+
+		// Feed the provided SQL to new.Run as its input stream instead of os.Stdin
+		sqlReader := strings.NewReader(params.SQL)
+		if err := new.Run(params.Name, sqlReader, afero.NewOsFs()); err != nil {
+			return nil, err
+		}
+
+		// List migrations to find the one we just created
+		fs := afero.NewOsFs()
+		migrationsDir, err := findMigrationsDir(fs)
+		if err != nil {
+			return nil, err
+		}
+		files, err := afero.ReadDir(fs, migrationsDir)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read migrations directory: %w", err)
+		}
+
+		// Find the most recent migration file that matches our name
+		var newMigrationFile string
+		var latestTime int64
+		for _, f := range files {
+			if !f.IsDir() && strings.HasSuffix(f.Name(), "_"+params.Name+".sql") {
+				// FileInfo already has ModTime()
+				if f.ModTime().Unix() > latestTime {
+					latestTime = f.ModTime().Unix()
+					newMigrationFile = f.Name()
+				}
+			}
+		}
+
+		if newMigrationFile == "" {
+			return nil, fmt.Errorf("failed to find newly created migration file for: %s", params.Name)
+		}
+
+		return &ToolResponse{
+			Result:  fmt.Sprintf("Created new migration:%s%s", LineSep, newMigrationFile),
+			Success: true,
+		}, nil
+
+	case "write_migration":
+		var params struct {
+			Version string `json:"version"`
+			SQL     string `json:"sql"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+
+		fs := afero.NewOsFs()
+		migrationsDir, err := findMigrationsDir(fs)
+		if err != nil {
+			return nil, err
+		}
+
+		// List available migrations first
+		files, err := afero.ReadDir(fs, migrationsDir)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read migrations directory: %w", err)
+		}
+
+		var migrations []string
+		for _, f := range files {
+			if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
+				migrations = append(migrations, f.Name())
+			}
+		}
+
+		// Find best matching file
+		matchingFile, err := findMigrationFile(fs, migrationsDir, params.Version)
+		if err != nil {
+			return nil, fmt.Errorf("available migrations:\n%s\nerror: %w",
+				strings.Join(migrations, "\n"), err)
+		}
+
+		migrationPath := filepath.Join(migrationsDir, matchingFile)
+		if err := afero.WriteFile(fs, migrationPath, []byte(params.SQL), 0644); err != nil {
+			return nil, fmt.Errorf("failed to write migration file: %w", err)
+		}
+
+		return &ToolResponse{
+			Result: fmt.Sprintf("Successfully wrote SQL to migration file: %s%sContent:%s%s",
+				migrationPath,
+				LineSep,
+				LineSep,
+				params.SQL),
+			Success: true,
+		}, nil
+
+	case "apply_migrations":
+		var params struct {
+			IncludeAll bool `json:"include_all"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+
+		fs := afero.NewOsFs()
+		migrationsDir, err := findMigrationsDir(fs)
+		if err != nil {
+			return nil, err
+		}
+
+		// List available migrations first
+		files, err := afero.ReadDir(fs, migrationsDir)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read migrations directory: %w", err)
+		}
+
+		var migrations []string
+		for _, f := range files {
+			if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
+				migrations = append(migrations, f.Name())
+			}
+		}
+
+		if len(migrations) == 0 {
+			return nil, fmt.Errorf("no migrations found in %s", migrationsDir)
+		}
+
+		fmt.Printf("Found migrations to apply:\n%s\n", strings.Join(migrations, "\n"))
+
+		// Create a copy of DbConfig with values from environment
+		migrationConfig := flags.DbConfig
+		migrationConfig.Database = os.Getenv("PGDATABASE")
+		if migrationConfig.Database == "" {
+			migrationConfig.Database = "postgres"
+		}
+		migrationConfig.Host = os.Getenv("PGHOST")
+		if migrationConfig.Host == "" {
+			migrationConfig.Host = "localhost"
+		}
+		migrationConfig.Port = parsePort(os.Getenv("PGPORT"), 54322)
+		migrationConfig.User = os.Getenv("PGUSER")
+		if migrationConfig.User == "" {
+			migrationConfig.User = "postgres"
+		}
+		migrationConfig.Password = os.Getenv("PGPASSWORD")
+		if migrationConfig.Password == "" {
+			migrationConfig.Password = "postgres"
+		}
+
+		// e.g. postgresql://postgres:****@localhost:54322/postgres?sslmode=disable
+		connStr := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
+			migrationConfig.User,
+			migrationConfig.Password,
+			migrationConfig.Host,
+			migrationConfig.Port,
+			migrationConfig.Database)
+
+		fmt.Printf("Attempting to connect to: %s\n", strings.Replace(connStr, migrationConfig.Password, "****", 1))
+
+		if err := applyMigrationsDirectly(migrationsDir, migrations, connStr); err != nil {
+			return nil, fmt.Errorf("error applying migrations: %w", err)
+		}
+
+		return &ToolResponse{
+			Result:  fmt.Sprintf("Successfully applied %d migrations to local database", len(migrations)),
+			Success: true,
+		}, nil
+
+	case "list_migrations":
+		fs := afero.NewOsFs()
+		migrationsDir, err := findMigrationsDir(fs)
+		if err != nil {
+			return nil, err
+		}
+		files, err := afero.ReadDir(fs, migrationsDir)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read migrations directory: %w", err)
+		}
+
+		var migrations []string
+		for _, f := range files {
+			if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
+				migrations = append(migrations, f.Name())
+			}
+		}
+
+		return &ToolResponse{
+			Result:  fmt.Sprintf("Available migrations:\n%s", strings.Join(migrations, "\n")),
+			Success: true,
+		}, nil
+
+	case "read_migration":
+		var params struct {
+			Version string `json:"version"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+
+		fs := afero.NewOsFs()
+		migrationsDir, err := findMigrationsDir(fs)
+		if err != nil {
+			return nil, err
+		}
+		matchingFile, err := findMigrationFile(fs, migrationsDir, params.Version)
+		if err != nil {
+			return nil, err
+		}
+
+		migrationPath := filepath.Join(migrationsDir, matchingFile)
+		content, err := afero.ReadFile(fs, migrationPath)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read migration file: %w", err)
+		}
+
+		return &ToolResponse{
+			Result:  string(content),
+			Success: true,
+		}, nil
+
+	case "get_cli_help":
+		var params struct {
+			Command string `json:"command"`
+		}
+		if err := json.Unmarshal([]byte(call.Arguments), &params); err != nil {
+			return nil, err
+		}
+
+		// Create a buffer to capture the help output
+		var buf strings.Builder
+
+		// Get the command to show help for
+		cmd := rootCmd
+		args := []string{}
+
+		if params.Command != "" {
+			args = append(args, params.Command, "--help")
+		} else {
+			args = append(args, "--help")
+		}
+
+		// Save the original output and error output
+		oldOut := cmd.OutOrStdout()
+		oldErr := cmd.ErrOrStderr()
+
+		// Redirect output to our buffer
+		cmd.SetOut(&buf)
+		cmd.SetErr(&buf)
+
+		// Execute the command with --help; any error output is captured in buf
+		cmd.SetArgs(args)
+		_ = cmd.Execute()
+
+		// Restore the original output
+		cmd.SetOut(oldOut)
+		cmd.SetErr(oldErr)
+
+		return &ToolResponse{
+			Result:  buf.String(),
+			Success: true,
+		}, nil
+
+	default:
+		return nil, fmt.Errorf("unknown function: %s", call.Name)
+	}
+}
+
+func parsePort(portStr string, defaultPort uint16) uint16 {
+	if portStr == "" {
+		return defaultPort
+	}
+	port, err := strconv.ParseUint(portStr, 10, 16)
+	if err != nil {
+		return defaultPort
+	}
+	return uint16(port)
+}
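+
+// findMigrationFile resolves a version string to a file via longest common
+// prefix, so a bare timestamp such as "20250220171330" still matches a file
+// like "20250220171330_add_likes.sql" (illustrative filename).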
+func findMigrationFile(fs afero.Fs, migrationsDir, timestamp string) (string, error) {
+	files, err := afero.ReadDir(fs, migrationsDir)
+	if err != nil {
+		return "", fmt.Errorf("failed to read migrations directory: %w", err)
+	}
+
+	// Clean the timestamp - remove .sql suffix if present
+	timestamp = strings.TrimSuffix(timestamp, ".sql")
+
+	var bestMatch string
+	var longestMatch int
+
+	for _, f := range files {
+		if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
+			name := strings.TrimSuffix(f.Name(), ".sql")
+			// Find the common prefix length
+			prefixLen := 0
+			for i := 0; i < len(timestamp) && i < len(name) && timestamp[i] == name[i]; i++ {
+				prefixLen++
+			}
+			// Update if this is the longest match so far
+			if prefixLen > longestMatch {
+				longestMatch = prefixLen
+				bestMatch = f.Name()
+			}
+		}
+	}
+
+	if bestMatch == "" {
+		return "", fmt.Errorf("no migration file found matching: %s", timestamp)
+	}
+
+	return bestMatch, nil
+}
+
+func findMigrationsDir(fs afero.Fs) (string, error) {
+	// Common migration directory patterns
+	patterns := []string{
+		"supabase/migrations",
+		"migrations",
+		"db/migrations",
+		"database/migrations",
+	}
+
+	// Try exact matches first
+	for _, path := range patterns {
+		exists, _ := afero.DirExists(fs, path)
+		if exists {
+			return path, nil
+		}
+	}
+
+	// If no exact match, look for the longest directory name containing "migration"
+	files, err := afero.ReadDir(fs, ".")
+	if err != nil {
+		return defaultMigrationsPath, nil // Fall back to default
+	}
+
+	var bestMatch string
+	var longestMatch int
+	for _, f := range files {
+		if f.IsDir() {
+			// Look for directories containing "migration"
+			if strings.Contains(strings.ToLower(f.Name()), "migration") {
+				if len(f.Name()) > longestMatch {
+					longestMatch = len(f.Name())
+					bestMatch = f.Name()
+				}
+			}
+		}
+	}
+
+	if bestMatch != "" {
+		return bestMatch, nil
+	}
+
+	return defaultMigrationsPath, nil
+}
+
+func applyMigrationsDirectly(migrationsDir string, files []string, connStr string) error {
+	db, err := sql.Open("postgres", connStr)
+	if err != nil {
+		return fmt.Errorf("failed to connect to database: %w", err)
+	}
+	defer db.Close()
+
+	// Create version tracking table if it doesn't exist
+	_, err = db.Exec(`
+		CREATE SCHEMA IF NOT EXISTS supabase_migrations;
+		CREATE TABLE IF NOT EXISTS supabase_migrations.schema_migrations (
+			version text PRIMARY KEY,
+			applied_at timestamptz DEFAULT now()
+		)
+	`)
+	if err != nil {
+		return fmt.Errorf("failed to create migrations table: %w", err)
+	}
+
+	// Apply each migration in a transaction
+	for _, file := range files {
+		var exists bool
+		err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM supabase_migrations.schema_migrations WHERE version = $1)", file).Scan(&exists)
+		if err != nil {
+			return fmt.Errorf("failed to check migration status: %w", err)
+		}
+		if exists {
+			fmt.Printf("Skipping %s (already applied)\n", file)
+			continue
+		}
+
+		fmt.Printf("Applying %s...\n", file)
+		content, err := afero.ReadFile(afero.NewOsFs(), filepath.Join(migrationsDir, file))
+		if err != nil {
+			return fmt.Errorf("failed to read migration file: %w", err)
+		}
+
+		tx, err := db.Begin()
+		if err != nil {
+			return fmt.Errorf("failed to start transaction: %w", err)
+		}
+
+		if _, err := tx.Exec(string(content)); err != nil {
+			tx.Rollback()
+			return fmt.Errorf("failed to apply migration %s: %w", file, err)
+		}
+
+		if _, err := tx.Exec("INSERT INTO supabase_migrations.schema_migrations (version) VALUES ($1)", file); err != nil {
+			tx.Rollback()
+			return fmt.Errorf("failed to record migration %s: %w", file, err)
+		}
+
+		if err := tx.Commit(); err != nil {
+			return fmt.Errorf("failed to commit migration %s: %w", file, err)
+		}
+		fmt.Printf("Successfully applied %s\n", file)
+	}
+	return nil
+}
diff --git a/cmd/assistant_ui.go b/cmd/assistant_ui.go
new file mode 100644
index 000000000..978393d05
--- /dev/null
+++ b/cmd/assistant_ui.go
@@ -0,0 +1,60 @@
+package cmd
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/fatih/color"
+)
+
+// UI elements shared across assistant commands
+var (
+	// Initialize color outputs for PowerShell compatibility
+	bannerColor = color.New(color.FgGreen) // Color for banners and prompts
+	assistantColor = color.New(color.FgCyan)     // Color for assistant responses
+	commandColor   = color.New(color.FgYellow)   // Color for command names
+	separatorColor = color.New(color.FgHiBlack)  // For the separator lines
+	errorColor     = color.New(color.FgRed)      // Color for error messages
+	toolColor      = color.RGB(0xF8, 0x83, 0x79) // Coral color for tool output (F88379)
+
+	separator     = separatorColor.Sprint(strings.Repeat("=", 50)) // Line of 50 equals signs
+	toolSeparator = separatorColor.Sprint(strings.Repeat("-", 30)) // Shorter separator for tools
+	prompt        = separatorColor.Sprint("\n> ")                  // For user input
+
+	helpText = fmt.Sprintf(`%s
+Welcome to the DNA Assistant. Available commands:
+  %s - End the chat session
+  %s - Show this help message
+  %s - Clear chat history
+
+%s
+%s
+%s`,
+		separator,
+		commandColor.Sprint("exit"),
+		commandColor.Sprint("help"),
+		commandColor.Sprint("clear"),
+		bannerColor.Sprint("Type your questions about database design and normalization."),
+		assistantColor.Sprint("Press Ctrl+C at any time to exit."),
+		separator)
+)
+
+// handleChatResponse formats and prints chat responses
+func handleChatResponse(response string) {
+	switch response {
+	case "exit":
+		fmt.Fprintln(color.Output, separator)
+		bannerColor.Fprintln(color.Output, "Goodbye! Feel free to return if you need more database design assistance.")
+		fmt.Fprintln(color.Output, separator)
+	case "help":
+		fmt.Fprint(color.Output, helpText)
+	case "clear":
+		fmt.Fprintln(color.Output, separator)
+		bannerColor.Fprintln(color.Output, "Chat history cleared.")
+		fmt.Fprintln(color.Output, separator)
+	default:
+		fmt.Fprintln(color.Output, separator)
+		assistantColor.Fprintln(color.Output, response)
+		fmt.Fprintln(color.Output, separator)
+	}
+}
diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go
index fd9f74912..a3688c9ce 100644
--- a/cmd/bootstrap.go
+++ b/cmd/bootstrap.go
@@ -23,10 +23,9 @@ var (
 	}
 
 	bootstrapCmd = &cobra.Command{
-		GroupID: groupQuickStart,
-		Use:     "bootstrap [template]",
-		Short:   "Bootstrap a Supabase project from a starter template",
-		Args:    cobra.MaximumNArgs(1),
+		Use:     "bootstrap",
+		Short:   "Bootstrap your project",
+		GroupID: "quick-start",
 		RunE: func(cmd *cobra.Command, args []string) error {
 			ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt)
 			if !viper.IsSet("WORKDIR") {
diff --git a/cmd/network_bans.go b/cmd/network_bans.go
new file mode 100644
index 000000000..6567774db
--- /dev/null
+++ b/cmd/network_bans.go
@@ -0,0 +1,15 @@
+package cmd
+
+import (
+	"github.com/spf13/cobra"
+)
+
+var networkBansCmd = &cobra.Command{
+	Use:     "network-bans",
+	Short:   "Manage network bans",
+	GroupID: "management-api",
+}
+
+func init() {
+	rootCmd.AddCommand(networkBansCmd)
+}
diff --git a/cmd/root.go b/cmd/root.go
index 35540a897..89456cada 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -17,7 +17,6 @@ import (
 	"github.com/spf13/viper"
 	"github.com/supabase/cli/internal/utils"
 	"github.com/supabase/cli/internal/utils/flags"
-	"golang.org/x/mod/semver"
 )
 
 const (
@@ -83,8 +82,10 @@ var (
 	createTicket bool
 
 	rootCmd = &cobra.Command{
-		Use:   "supabase",
-		Short: "Supabase CLI " + utils.Version,
+		Use:   "dna",
+		Short: "DNA CLI - Database Normalization Assistant",
+		Long: `DNA CLI helps you analyze and normalize your database schemas.
+It provides tools for schema analysis, normalization suggestions, and best practices.`, Version: utils.Version, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if IsExperimental(cmd) && !viper.GetBool("EXPERIMENTAL") { @@ -131,20 +132,9 @@ var ( ) func Execute() { - defer recoverAndExit() if err := rootCmd.Execute(); err != nil { - panic(err) - } - // Check upgrade last because --version flag is initialised after execute - version, err := checkUpgrade(rootCmd.Context(), afero.NewOsFs()) - if err != nil { - fmt.Fprintln(utils.GetDebugLogger(), err) - } - if semver.Compare(version, "v"+utils.Version) > 0 { - fmt.Fprintln(os.Stderr, suggestUpgrade(version)) - } - if len(utils.CmdSuggestion) > 0 { - fmt.Fprintln(os.Stderr, utils.CmdSuggestion) + fmt.Fprintln(os.Stderr, err) + os.Exit(1) } } @@ -219,27 +209,33 @@ func recoverAndExit() { os.Exit(1) } -func init() { - cobra.OnInitialize(func() { - viper.SetEnvPrefix("SUPABASE") - viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) - viper.AutomaticEnv() - }) +// Move these to the top, before any commands are defined +var ( + quickStartGroup = &cobra.Group{ + ID: "quick-start", + Title: "Quick Start Commands", + } + + localDevGroup = &cobra.Group{ + ID: "local-dev", + Title: "Local Development Commands", + } - flags := rootCmd.PersistentFlags() - flags.Bool("debug", false, "output debug logs to stderr") - flags.String("workdir", "", "path to a Supabase project directory") - flags.Bool("experimental", false, "enable experimental features") - flags.String("network-id", "", "use the specified docker network instead of a generated one") - flags.Var(&utils.OutputFormat, "output", "output format of status variables") - flags.Var(&utils.DNSResolver, "dns-resolver", "lookup domain names using the specified resolver") - flags.BoolVar(&createTicket, "create-ticket", false, "create a support ticket for any CLI error") - cobra.CheckErr(viper.BindPFlags(flags)) + managementAPIGroup = &cobra.Group{ + ID: "management-api", + Title: "Management API Commands", + } +) + +func init() { + // Add groups first + rootCmd.AddGroup(quickStartGroup) + rootCmd.AddGroup(localDevGroup) + rootCmd.AddGroup(managementAPIGroup) - rootCmd.SetVersionTemplate("{{.Version}}\n") - rootCmd.AddGroup(&cobra.Group{ID: groupQuickStart, Title: "Quick Start:"}) - rootCmd.AddGroup(&cobra.Group{ID: groupLocalDev, Title: "Local Development:"}) - rootCmd.AddGroup(&cobra.Group{ID: groupManagementAPI, Title: "Management APIs:"}) + // Then add commands with group + assistantCmd.GroupID = quickStartGroup.ID + rootCmd.AddCommand(assistantCmd) } // instantiate new rootCmd is a bit tricky with cobra, but it can be done later with the following diff --git a/go.mod b/go.mod index 2f8675351..a94350ccd 100644 --- a/go.mod +++ b/go.mod @@ -220,7 +220,7 @@ require ( github.com/mattn/go-runewidth v0.0.16 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/mgechev/revive v1.6.1 // indirect - github.com/microcosm-cc/bluemonday v1.0.25 // indirect + github.com/microcosm-cc/bluemonday v1.0.26 // indirect github.com/miekg/pkcs11 v1.1.1 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect @@ -265,6 +265,7 @@ require ( github.com/sahilm/fuzzy v0.1.1-0.20230530133925-c48e322e2a8f // indirect github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect + github.com/sashabaranov/go-openai v1.37.0 // indirect 
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect github.com/securego/gosec/v2 v2.22.1 // indirect @@ -286,6 +287,7 @@ require ( github.com/theupdateframework/notary v0.7.0 // indirect github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect github.com/timonwong/loggercheck v0.10.1 // indirect + github.com/tmc/langchaingo v0.1.13 // indirect github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect diff --git a/go.sum b/go.sum index 9eed1ad36..8e07754fa 100644 --- a/go.sum +++ b/go.sum @@ -690,6 +690,8 @@ github.com/mgechev/revive v1.6.1 h1:ncK0ZCMWtb8GXwVAmk+IeWF2ULIDsvRxSRfg5sTwQ2w= github.com/mgechev/revive v1.6.1/go.mod h1:/2tfHWVO8UQi/hqJsIYNEKELi+DJy/e+PQpLgTB1v88= github.com/microcosm-cc/bluemonday v1.0.25 h1:4NEwSfiJ+Wva0VxN5B8OwMicaJvD8r9tlJWm9rtloEg= github.com/microcosm-cc/bluemonday v1.0.25/go.mod h1:ZIOjCQp1OrzBBPIJmfX4qDYFuhU02nx4bn030ixfHLE= +github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= +github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs= github.com/miekg/pkcs11 v1.0.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU= github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= @@ -855,6 +857,8 @@ github.com/sanposhiho/wastedassign/v2 v2.1.0 h1:crurBF7fJKIORrV85u9UUpePDYGWnwvv github.com/sanposhiho/wastedassign/v2 v2.1.0/go.mod h1:+oSmSC+9bQ+VUAxA66nBb0Z7N8CK7mscKTDYC6aIek4= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= +github.com/sashabaranov/go-openai v1.37.0 h1:hQQowgYm4OXJ1Z/wTrE+XZaO20BYsL0R3uRPSpfNZkY= +github.com/sashabaranov/go-openai v1.37.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashamelentyev/interfacebloat v1.1.0 h1:xdRdJp0irL086OyW1H/RTZTr1h/tMEOsumirXcOJqAw= github.com/sashamelentyev/interfacebloat v1.1.0/go.mod h1:+Y9yU5YdTkrNvoX0xHc84dxiN1iBi9+G8zZIhPVoNjQ= github.com/sashamelentyev/usestdlibvars v1.28.0 h1:jZnudE2zKCtYlGzLVreNp5pmCdOxXUzwsMDBkR21cyQ= @@ -950,6 +954,8 @@ github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 h1:y4mJRFlM6fUyP github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3/go.mod h1:mkjARE7Yr8qU23YcGMSALbIxTQ9r9QBVahQOBRfU460= github.com/timonwong/loggercheck v0.10.1 h1:uVZYClxQFpw55eh+PIoqM7uAOHMrhVcDoWDery9R8Lg= github.com/timonwong/loggercheck v0.10.1/go.mod h1:HEAWU8djynujaAVX7QI65Myb8qgfcZ1uKbdpg3ZzKl8= +github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA= +github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg= github.com/tomarrell/wrapcheck/v2 v2.10.0 h1:SzRCryzy4IrAH7bVGG4cK40tNUhmVmMDuJujy4XwYDg= github.com/tomarrell/wrapcheck/v2 v2.10.0/go.mod h1:g9vNIyhb5/9TQgumxQyOEqDHsmGYcGsVMOx/xGkqdMo= github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw= diff --git a/internal/dna/assistant/chat.go b/internal/dna/assistant/chat.go new file mode 100644 index 000000000..23c6a2b17 --- /dev/null +++ b/internal/dna/assistant/chat.go @@ -0,0 +1,86 @@ +package assistant + +import ( + "bufio" + "context" + "fmt" + "os" + "strings" + + 
"github.com/supabase/cli/internal/dna/langchain" + "github.com/tmc/langchaingo/llms/openai" +) + +func printWelcomeMessage() { + fmt.Println("Welcome to DNA Assistant! Available commands:") + fmt.Println(" exit - Exit the chat session") + fmt.Println(" help - Show this help message") + fmt.Println(" clear - Clear the screen") + fmt.Println("\nType your questions about database design, normalization, or Supabase.") + fmt.Println("I'll use my knowledge base to provide detailed answers.") + fmt.Println("\nType 'exit' when you're done.") +} + +func handleCommand(cmd string) bool { + switch cmd { + case "exit": + fmt.Println("Goodbye!") + return true + case "help": + printWelcomeMessage() + case "clear": + fmt.Print("\033[H\033[2J") + } + return false +} + +func Chat() error { + apiKey := os.Getenv("DNA_API_KEY") + if apiKey == "" { + return fmt.Errorf("DNA_API_KEY environment variable not set") + } + + // Initialize OpenAI LLM + llm, err := openai.New() + if err != nil { + return fmt.Errorf("failed to initialize OpenAI: %w", err) + } + + // Initialize RAG system + rag, err := langchain.NewRAG(apiKey) + if err != nil { + return fmt.Errorf("failed to initialize RAG: %w", err) + } + + printWelcomeMessage() + + scanner := bufio.NewScanner(os.Stdin) + ctx := context.Background() + + for { + fmt.Print("\nYou: ") + if !scanner.Scan() { + break + } + + input := strings.TrimSpace(scanner.Text()) + if input == "" { + continue + } + + if handleCommand(input) { + break + } + + // Use RAG to get response + response, err := rag.Query(ctx, input, llm) + if err != nil { + fmt.Printf("Error: %v\n", err) + continue + } + + fmt.Printf("\nAssistant: %s\n", response) + } + + return scanner.Err() +} diff --git a/internal/dna/config/config.go b/internal/dna/config/config.go new file mode 100644 index 000000000..1798f4b1a --- /dev/null +++ b/internal/dna/config/config.go @@ -0,0 +1,107 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/BurntSushi/toml" + "github.com/spf13/afero" +) + +// Config holds the DNA assistant configuration +type Config struct { + APIKey string `toml:"api_key"` + Provider string `toml:"provider"` + Model string `toml:"model"` + Temperature float32 `toml:"temperature"` +} + +const ( + // DefaultConfigPath is the default location for the DNA config file + DefaultConfigPath = "dna.config.toml" +) + +// Default configuration values +var DefaultConfig = Config{ + Provider: "openai", + Model: "gpt-4", + Temperature: 0.7, +} + +// Load reads the configuration from the config file and environment variables +func Load(fs afero.Fs) (*Config, error) { + config := DefaultConfig + + // Try to load from config file + configPath := getConfigPath() + if exists, _ := afero.Exists(fs, configPath); exists { + if _, err := toml.DecodeFile(configPath, &config); err != nil { + return nil, fmt.Errorf("error reading config file: %w", err) + } + } + + // Override with environment variables if set + if apiKey := os.Getenv("DNA_API_KEY"); apiKey != "" { + config.APIKey = apiKey + } + if provider := os.Getenv("DNA_PROVIDER"); provider != "" { + config.Provider = provider + } + if model := os.Getenv("DNA_MODEL"); model != "" { + config.Model = model + } + + // Validate config + if err := config.Validate(); err != nil { + return nil, err + } + + return &config, nil +} + +// Save writes the configuration to the config file +func (c *Config) Save(fs afero.Fs) error { + configPath := getConfigPath() + + // Create directory if it doesn't exist + configDir := filepath.Dir(configPath) + if err := 
fs.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("error creating config directory: %w", err) + } + + // Create or truncate config file + f, err := fs.Create(configPath) + if err != nil { + return fmt.Errorf("error creating config file: %w", err) + } + defer f.Close() + + // Write config + if err := toml.NewEncoder(f).Encode(c); err != nil { + return fmt.Errorf("error writing config file: %w", err) + } + + return nil +} + +// Validate checks if the configuration is valid +func (c *Config) Validate() error { + if c.APIKey == "" { + return fmt.Errorf("API key is required. Set it in %s or use DNA_API_KEY environment variable", DefaultConfigPath) + } + if c.Provider == "" { + return fmt.Errorf("provider is required") + } + if c.Model == "" { + return fmt.Errorf("model is required") + } + return nil +} + +func getConfigPath() string { + if path := os.Getenv("DNA_CONFIG_PATH"); path != "" { + return path + } + return DefaultConfigPath +} diff --git a/internal/dna/langchain/provider.go b/internal/dna/langchain/provider.go new file mode 100644 index 000000000..1becfc7a6 --- /dev/null +++ b/internal/dna/langchain/provider.go @@ -0,0 +1,61 @@ +package langchain + +import ( + "context" + "fmt" + "time" + + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/llms/openai" + "github.com/tmc/langchaingo/schema" +) + +// Metrics tracks usage and performance metrics +type Metrics struct { + Tokens int + PromptTokens int + ResponseTime time.Duration + TotalRequests int + Errors int +} + +// Provider wraps LangChain functionality with metrics +type Provider struct { + llm llms.LLM + metrics *Metrics +} + +// NewProvider creates a new LangChain provider with the given API key +func NewProvider(apiKey string) (*Provider, error) { + llm, err := openai.NewChat(openai.WithToken(apiKey)) + if err != nil { + return nil, fmt.Errorf("failed to create LangChain client: %w", err) + } + + return &Provider{ + llm: llm, + metrics: &Metrics{}, + }, nil +} + +// GetResponse generates a response using LangChain with metrics +func (p *Provider) GetResponse(ctx context.Context, messages []schema.ChatMessage) (string, error) { + start := time.Now() + p.metrics.TotalRequests++ + + completion, err := p.llm.Call(ctx, messages) + if err != nil { + p.metrics.Errors++ + return "", fmt.Errorf("LangChain call failed: %w", err) + } + + p.metrics.ResponseTime = time.Since(start) + // Note: Token counting would be implemented here + + return completion, nil +} + +// GetMetrics returns the current metrics +func (p *Provider) GetMetrics() Metrics { + return *p.metrics +} diff --git a/internal/dna/langchain/rag.go b/internal/dna/langchain/rag.go new file mode 100644 index 000000000..8191fd4a1 --- /dev/null +++ b/internal/dna/langchain/rag.go @@ -0,0 +1,431 @@ +package langchain + +import ( + "context" + "fmt" + "strings" + + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/embeddings/openai" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" + "github.com/tmc/langchaingo/vectorstores" +) + +// Document represents a piece of documentation +type Document struct { + Content string + Metadata map[string]interface{} +} + +// Source represents a knowledge source +type Source struct { + Title string + URL string + Author string + Date string +} + +// DocumentWithSource adds source attribution to a document +type DocumentWithSource struct { + Document + Source Source +} + +// RAG handles retrieval augmented generation +type RAG struct { + store vectorstores.VectorStore + embedder 
embeddings.Embedder + documents []Document + chunkSize int + numResults int +} + +// PostgreSQLDoc represents an official PostgreSQL documentation section +type PostgreSQLDoc struct { + Source + Chapter string + Section string + Version string +} + +// NewRAG creates a new RAG instance +func NewRAG(apiKey string) (*RAG, error) { + embedder := openai.NewEmbedder(apiKey) + + // Initialize with some basic documentation + docs := []Document{ + { + Content: `Database normalization is the process of structuring a database to reduce data redundancy and improve data integrity. + +First Normal Form (1NF): +- Each table cell should contain a single value (atomic values) +- Each record needs to be unique +- No repeating groups or arrays +- Example: Instead of storing comma-separated categories in a 'tags' column, create a separate tags table`, + Metadata: map[string]interface{}{ + "type": "normalization", + "form": "1NF", + }, + }, + { + Content: `Second Normal Form (2NF): +- Must be in 1NF +- All non-key attributes are fully dependent on the primary key +- No partial dependencies +- Example: In a product_orders table, if order_date depends only on order_id (not product_id), it should be in a separate orders table`, + Metadata: map[string]interface{}{ + "type": "normalization", + "form": "2NF", + }, + }, + { + Content: `Third Normal Form (3NF): +- Must be in 2NF +- No transitive dependencies +- All fields must depend directly on the primary key +- Example: If product_category_name depends on category_id which depends on product_id, move category data to a separate table`, + Metadata: map[string]interface{}{ + "type": "normalization", + "form": "3NF", + }, + }, + { + Content: `Complex Relationships in PostgreSQL: +For handling complex relationships like products belonging to multiple categories: + +1. Use bridge tables for many-to-many relationships: +CREATE TABLE products ( + product_id SERIAL PRIMARY KEY, + name VARCHAR(255), + base_price DECIMAL(10,2) +); + +CREATE TABLE categories ( + category_id SERIAL PRIMARY KEY, + name VARCHAR(255) +); + +CREATE TABLE product_categories ( + product_id INTEGER REFERENCES products(product_id), + category_id INTEGER REFERENCES categories(category_id), + PRIMARY KEY (product_id, category_id) +);`, + Metadata: map[string]interface{}{ + "type": "implementation", + "topic": "relationships", + }, + }, + { + Content: `Performance Optimization with Normalized Data: +1. Create indexes on frequently joined columns: +CREATE INDEX idx_product_categories_product_id ON product_categories(product_id); +CREATE INDEX idx_product_categories_category_id ON product_categories(category_id); + +2. Use materialized views for complex queries that don't need real-time data: +CREATE MATERIALIZED VIEW product_category_summary AS +SELECT p.name AS product_name, + string_agg(c.name, ', ') AS categories, + p.base_price +FROM products p +JOIN product_categories pc ON p.product_id = pc.product_id +JOIN categories c ON pc.category_id = c.category_id +GROUP BY p.product_id, p.name, p.base_price;`, + Metadata: map[string]interface{}{ + "type": "optimization", + "topic": "performance", + }, + }, + { + Content: `Constraint Management in PostgreSQL: +Use PostgreSQL's constraint system to maintain data integrity: + +1. Check constraints for business rules: +ALTER TABLE products +ADD CONSTRAINT price_check +CHECK (base_price >= 0); + +2. Unique constraints for data uniqueness: +ALTER TABLE categories +ADD CONSTRAINT unique_category_name +UNIQUE (name); + +3. 
Foreign key constraints for referential integrity: +ALTER TABLE product_categories +ADD CONSTRAINT fk_product +FOREIGN KEY (product_id) +REFERENCES products(product_id) +ON DELETE CASCADE;`, + Metadata: map[string]interface{}{ + "type": "implementation", + "topic": "constraints", + }, + }, + { + Content: `When to Consider Denormalization: +1. When query performance is critical and joins are expensive +2. For read-heavy workloads with infrequent updates +3. When maintaining real-time aggregations +4. For time-series data with specific access patterns + +Note: Always measure performance impact before denormalizing.`, + Metadata: map[string]interface{}{ + "type": "optimization", + "topic": "denormalization", + }, + }, + { + Content: `Supabase Row Level Security (RLS): +- Control access to rows in database tables +- Define policies using SQL +- Automatically applied to all queries +- Essential for multi-tenant applications +- Integrates with Supabase Auth`, + Metadata: map[string]interface{}{ + "type": "supabase", + "feature": "rls", + }, + }, + } + + // Create vector store + store := vectorstores.NewMemory(embedder) + + rag := &RAG{ + store: store, + embedder: embedder, + documents: docs, + chunkSize: 500, + numResults: 3, + } + + // Add documents to store + if err := rag.initializeStore(); err != nil { + return nil, err + } + + return rag, nil +} + +// initializeStore adds all documents to the vector store +func (r *RAG) initializeStore() error { + for _, doc := range r.documents { + chunks := r.chunkText(doc.Content) + for _, chunk := range chunks { + _, err := r.store.AddDocuments(context.Background(), []schema.Document{ + { + PageContent: chunk, + Metadata: doc.Metadata, + }, + }) + if err != nil { + return fmt.Errorf("failed to add document to store: %w", err) + } + } + } + return nil +} + +// chunkText splits text into smaller chunks +func (r *RAG) chunkText(text string) []string { + words := strings.Fields(text) + var chunks []string + for i := 0; i < len(words); i += r.chunkSize { + end := i + r.chunkSize + if end > len(words) { + end = len(words) + } + chunks = append(chunks, strings.Join(words[i:end], " ")) + } + return chunks +} + +// Query searches for relevant documents and augments the prompt +func (r *RAG) Query(ctx context.Context, query string, llm llms.LLM) (string, error) { + // First, search both PostgreSQL and Supabase docs for relevant content + searchErrors := make([]error, 0) + + // Search PostgreSQL docs + if err := r.SearchAndAddToRAG(ctx, query, 3); err != nil { + searchErrors = append(searchErrors, fmt.Errorf("PostgreSQL search failed: %w", err)) + } + + // Search Supabase docs + if err := r.SearchAndAddSupabaseToRAG(ctx, query, 3); err != nil { + searchErrors = append(searchErrors, fmt.Errorf("Supabase search failed: %w", err)) + } + + // Log search errors but continue - we still have our base knowledge + if len(searchErrors) > 0 { + fmt.Println("Warning: Some documentation searches failed:") + for _, err := range searchErrors { + fmt.Printf("- %v\n", err) + } + } + + // Search for relevant documents in our knowledge base + results, err := r.store.SimilaritySearch(ctx, query, r.numResults) + if err != nil { + return "", fmt.Errorf("similarity search failed: %w", err) + } + + // Build augmented prompt + var relevantDocs strings.Builder + for _, doc := range results { + relevantDocs.WriteString(doc.PageContent) + if metadata, ok := doc.Metadata["source"].(Source); ok { + relevantDocs.WriteString(fmt.Sprintf("\n\nSource: %s (%s)", metadata.Title, metadata.URL)) + } + 
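+			// Blank line between retrieved chunks keeps clear document boundaries in the prompt.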
relevantDocs.WriteString("\n\n") + } + + // Create augmented prompt + augmentedPrompt := fmt.Sprintf(`Based on the following documentation: + +%s + +Answer this question: %s + +Provide specific examples and references to the documentation where relevant. If you're referencing Supabase-specific features, make sure to explain how they relate to standard PostgreSQL concepts.`, relevantDocs.String(), query) + + // Get response from LLM + messages := []schema.ChatMessage{ + &schema.SystemMessage{Content: "You are a database design expert focusing on PostgreSQL and Supabase. Explain concepts clearly and provide practical examples."}, + &schema.HumanMessage{Content: augmentedPrompt}, + } + + completion, err := llm.GenerateContent(ctx, messages, llms.WithTemperature(0.7)) + if err != nil { + return "", fmt.Errorf("LLM call failed: %w", err) + } + + return completion.Content, nil +} + +// AddDocument adds a new document to the RAG system +func (r *RAG) AddDocument(ctx context.Context, doc Document) error { + r.documents = append(r.documents, doc) + chunks := r.chunkText(doc.Content) + for _, chunk := range chunks { + _, err := r.store.AddDocuments(ctx, []schema.Document{ + { + PageContent: chunk, + Metadata: doc.Metadata, + }, + }) + if err != nil { + return fmt.Errorf("failed to add document to store: %w", err) + } + } + return nil +} + +// AddKnowledgeSource adds a new knowledge source with proper attribution +func (r *RAG) AddKnowledgeSource(ctx context.Context, content string, source Source, topics []string) error { + // Create metadata with source information and topics + metadata := map[string]interface{}{ + "source": source, + "topics": topics, + } + + // Add the document with source attribution + return r.AddDocument(ctx, Document{ + Content: content, + Metadata: metadata, + }) +} + +// AddPostgreSQLDoc adds official PostgreSQL documentation with proper structure +func (r *RAG) AddPostgreSQLDoc(ctx context.Context, doc PostgreSQLDoc, content string) error { + source := Source{ + Title: fmt.Sprintf("PostgreSQL %s: %s - %s", doc.Version, doc.Chapter, doc.Section), + URL: fmt.Sprintf("https://www.postgresql.org/docs/%s/%s.html", doc.Version, strings.ToLower(strings.ReplaceAll(doc.Chapter, " ", "-"))), + Author: "PostgreSQL Global Development Group", + Date: fmt.Sprintf("PostgreSQL %s", doc.Version), + } + + metadata := map[string]interface{}{ + "source": source, + "type": "postgresql_docs", + "version": doc.Version, + "chapter": doc.Chapter, + "section": doc.Section, + "is_official": true, + } + + return r.AddDocument(ctx, Document{ + Content: content, + Metadata: metadata, + }) +} + +// Example usage: +func ExampleAddTimescaleArticle(ctx context.Context, r *RAG) error { + source := Source{ + Title: "How to Use PostgreSQL for Data Normalization", + URL: "https://www.timescale.com/learn/how-to-use-postgresql-for-data-normalization", + Author: "Timescale", + Date: "2024", + } + + content := `Handling Complex Relationships in PostgreSQL: +When it comes to data normalization in PostgreSQL, managing complex relationships between different entities is one of the most significant challenges. For example, in an e-commerce platform where products can belong to multiple categories and have various attributes: + +Best Practices: +1. Use bridge tables for many-to-many relationships +2. Implement appropriate indexing strategies +3. Use materialized views for complex, frequently-accessed data +4. 
Apply constraints to maintain data integrity + +Performance Considerations: +- Create indexes on frequently joined columns +- Use materialized views for complex queries that don't need real-time data +- Consider denormalization only when necessary and after measuring performance impact + +When to Consider Denormalization: +1. For read-heavy workloads with infrequent updates +2. When query performance is critical +3. For time-series data with specific access patterns +4. When maintaining real-time aggregations + +Always validate the impact of denormalization through testing and measurement.` + + return r.AddKnowledgeSource(ctx, content, source, []string{ + "normalization", + "relationships", + "performance", + "postgresql", + "denormalization", + }) +} + +func ExampleAddPostgreSQLNormalizationDocs(ctx context.Context, r *RAG) error { + // Add Database Design chapter + doc := PostgreSQLDoc{ + Version: "17", + Chapter: "Database Design", + Section: "Data Modeling", + } + + content := `Database Design Principles in PostgreSQL: + +Table Design Guidelines: +1. Choose the right data types +2. Normalize data appropriately +3. Use constraints to enforce data integrity +4. Consider indexing strategy from the start + +Normalization Guidelines: +- Break down tables to eliminate redundancy +- Ensure each column serves a single purpose +- Use foreign keys to maintain relationships +- Consider the impact on query performance + +PostgreSQL-specific features for better design: +- JSONB for semi-structured data +- Array types when appropriate +- Inheritance for table hierarchies +- Partitioning for large tables` + + return r.AddPostgreSQLDoc(ctx, doc, content) +} diff --git a/internal/dna/langchain/search.go b/internal/dna/langchain/search.go new file mode 100644 index 000000000..14358d53b --- /dev/null +++ b/internal/dna/langchain/search.go @@ -0,0 +1,232 @@ +package langchain + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// SearchResult represents a PostgreSQL documentation search result +type SearchResult struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` + Section string `json:"section"` +} + +// SupabaseSearchResult represents a Supabase documentation search result +type SupabaseSearchResult struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` + Category string `json:"category"` +} + +// SearchPostgresDocs searches PostgreSQL documentation and returns top results +func SearchPostgresDocs(ctx context.Context, query string, limit int) ([]SearchResult, error) { + // Build search URL + baseURL := "https://www.postgresql.org/search/" + params := url.Values{} + params.Add("q", query) + params.Add("u", "/docs/17/") // Search in latest docs + params.Add("fmt", "json") // Request JSON response + params.Add("limit", fmt.Sprintf("%d", limit)) + + // Create request + req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"?"+params.Encode(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Send request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Parse results + var results []SearchResult + if err := json.Unmarshal(body, &results); err != nil { + 
return nil, fmt.Errorf("failed to parse response: %w", err)
+	}
+
+	return results, nil
+}
+
+// SearchSupabaseDocs searches Supabase documentation and returns top results
+func SearchSupabaseDocs(ctx context.Context, query string, limit int) ([]SupabaseSearchResult, error) {
+	// Base URLs for Supabase documentation
+	baseURLs := []string{
+		"https://supabase.com/docs/guides/database",
+		"https://supabase.com/docs/guides/auth",
+		"https://supabase.com/docs/guides/functions",
+		"https://supabase.com/docs/guides/realtime",
+	}
+
+	// Create HTTP client with context
+	client := &http.Client{}
+
+	var results []SupabaseSearchResult
+	for _, baseURL := range baseURLs {
+		// Create request
+		req, err := http.NewRequestWithContext(ctx, "GET", baseURL, nil)
+		if err != nil {
+			continue // Skip if we can't access this section
+		}
+
+		// Send request
+		resp, err := client.Do(req)
+		if err != nil {
+			continue
+		}
+
+		// Read the response and close the body immediately; a deferred
+		// Close inside a loop would keep every body open until the
+		// function returns.
+		body, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		if err != nil {
+			continue
+		}
+
+		// Extract relevant content (this is a simplified version)
+		// In a real implementation, we would parse the HTML and extract structured content
+		category := strings.TrimPrefix(baseURL, "https://supabase.com/docs/guides/")
+		title := category
+		if len(title) > 0 {
+			// Capitalize manually; strings.Title is deprecated.
+			title = strings.ToUpper(title[:1]) + title[1:]
+		}
+		results = append(results, SupabaseSearchResult{
+			Title:    fmt.Sprintf("Supabase %s Guide", title),
+			URL:      baseURL,
+			Content:  string(body),
+			Category: category,
+		})
+
+		if len(results) >= limit {
+			break
+		}
+	}
+
+	return results, nil
+}
+
+// AddSearchResultToRAG adds a search result to the RAG system
+func (r *RAG) AddSearchResultToRAG(ctx context.Context, result SearchResult) error {
+	source := Source{
+		Title:  result.Title,
+		URL:    result.URL,
+		Author: "PostgreSQL Global Development Group",
+		Date:   "PostgreSQL 17",
+	}
+
+	return r.AddKnowledgeSource(ctx, result.Content, source, []string{
+		"postgresql_docs",
+		"search_result",
+		result.Section,
+	})
+}
+
+// AddSupabaseSearchResultToRAG adds a Supabase search result to the RAG system
+func (r *RAG) AddSupabaseSearchResultToRAG(ctx context.Context, result SupabaseSearchResult) error {
+	source := Source{
+		Title:  result.Title,
+		URL:    result.URL,
+		Author: "Supabase",
+		Date:   time.Now().Format("2006-01-02"), // Current date as docs are regularly updated
+	}
+
+	return r.AddKnowledgeSource(ctx, result.Content, source, []string{
+		"supabase_docs",
+		"search_result",
+		result.Category,
+	})
+}
+
+// SearchAndAddToRAG searches PostgreSQL docs and adds results to RAG
+func (r *RAG) SearchAndAddToRAG(ctx context.Context, query string, limit int) error {
+	// Search docs
+	results, err := SearchPostgresDocs(ctx, query, limit)
+	if err != nil {
+		return fmt.Errorf("search failed: %w", err)
+	}
+
+	// Add each result to RAG
+	for _, result := range results {
+		if err := r.AddSearchResultToRAG(ctx, result); err != nil {
+			return fmt.Errorf("failed to add result to RAG: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// SearchAndAddSupabaseToRAG searches Supabase docs and adds results to RAG
+func (r *RAG) SearchAndAddSupabaseToRAG(ctx context.Context, query string, limit int) error {
+	// Search docs
+	results, err := SearchSupabaseDocs(ctx, query, limit)
+	if err != nil {
+		return fmt.Errorf("search failed: %w", err)
+	}
+
+	// Add each result to RAG
+	for _, result := range results {
+		if err := r.AddSupabaseSearchResultToRAG(ctx, result); err != nil {
+			return fmt.Errorf("failed to add result to RAG: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// Example usage:
+func ExampleSearchAndAdd(ctx 
context.Context, r *RAG) error { + // Search for normalization-related docs and add top 3 results + if err := r.SearchAndAddToRAG(ctx, "database normalization best practices", 3); err != nil { + return err + } + + // Search for specific topics + queries := []string{ + "table relationships many-to-many", + "database constraints foreign keys", + "indexing strategy", + } + + for _, query := range queries { + if err := r.SearchAndAddToRAG(ctx, query, 2); err != nil { + return err + } + } + + return nil +} + +func ExampleSearchSupabase(ctx context.Context, r *RAG) error { + // Search for database-related docs and add top 3 results + if err := r.SearchAndAddSupabaseToRAG(ctx, "row level security best practices", 3); err != nil { + return err + } + + // Search for specific topics + queries := []string{ + "foreign key relationships", + "database functions", + "real-time subscriptions", + } + + for _, query := range queries { + if err := r.SearchAndAddSupabaseToRAG(ctx, query, 2); err != nil { + return err + } + } + + return nil +} diff --git a/internal/dna/schema/analyzer.go b/internal/dna/schema/analyzer.go new file mode 100644 index 000000000..6063dd70d --- /dev/null +++ b/internal/dna/schema/analyzer.go @@ -0,0 +1,381 @@ +package schema + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +// SchemaInfo represents the database schema information +type SchemaInfo struct { + Tables []TableInfo + Extensions []string + Constraints []ConstraintInfo + Indexes []IndexInfo +} + +// TableInfo represents a database table +type TableInfo struct { + Name string + Columns []ColumnInfo + Constraints []ConstraintInfo + Indexes []IndexInfo + RowCount int64 +} + +// ColumnInfo represents a table column +type ColumnInfo struct { + Name string + Type string + IsNullable bool + DefaultValue sql.NullString + IsPrimaryKey bool + IsForeignKey bool + References *ForeignKeyInfo +} + +// ConstraintInfo represents a table constraint +type ConstraintInfo struct { + Name string + Type string + Table string + Columns []string + Definition string +} + +// IndexInfo represents a table index +type IndexInfo struct { + Name string + Table string + Columns []string + IsUnique bool + Definition string +} + +// ForeignKeyInfo represents a foreign key relationship +type ForeignKeyInfo struct { + Table string + Column string + RefTable string + RefColumn string +} + +// Analyzer provides methods to analyze database schema +type Analyzer struct { + db *sql.DB +} + +// NewAnalyzer creates a new schema analyzer +func NewAnalyzer(db *sql.DB) *Analyzer { + return &Analyzer{db: db} +} + +// GetSchemaInfo retrieves complete schema information +func (a *Analyzer) GetSchemaInfo(ctx context.Context) (*SchemaInfo, error) { + info := &SchemaInfo{} + + // Get enabled extensions + extensions, err := a.getExtensions(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get extensions: %w", err) + } + info.Extensions = extensions + + // Get tables + tables, err := a.getTables(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get tables: %w", err) + } + info.Tables = tables + + // Get constraints + constraints, err := a.getConstraints(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get constraints: %w", err) + } + info.Constraints = constraints + + // Get indexes + indexes, err := a.getIndexes(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get indexes: %w", err) + } + info.Indexes = indexes + + return info, nil +} + +func (a *Analyzer) getExtensions(ctx context.Context) ([]string, error) { + query := 
`
+		SELECT extname
+		FROM pg_extension
+		WHERE extname != 'plpgsql'
+		ORDER BY extname;
+	`
+
+	rows, err := a.db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var extensions []string
+	for rows.Next() {
+		var ext string
+		if err := rows.Scan(&ext); err != nil {
+			return nil, err
+		}
+		extensions = append(extensions, ext)
+	}
+
+	return extensions, rows.Err()
+}
+
+func (a *Analyzer) getTables(ctx context.Context) ([]TableInfo, error) {
+	// Fetch one row per column, ordered by ordinal position. Aggregating
+	// with array_agg(DISTINCT ...) sorts and de-duplicates each attribute
+	// array independently, so names, types, nullability and defaults stop
+	// lining up with each other; Postgres arrays also cannot be scanned
+	// into []string with database/sql. The EXISTS subquery flags primary
+	// key columns, which the normalization checks below rely on.
+	query := `
+		SELECT
+			t.table_name,
+			c.column_name,
+			c.data_type,
+			c.is_nullable,
+			c.column_default,
+			EXISTS (
+				SELECT 1
+				FROM information_schema.table_constraints tc
+				JOIN information_schema.key_column_usage kcu
+					ON tc.constraint_name = kcu.constraint_name
+					AND tc.table_schema = kcu.table_schema
+				WHERE tc.constraint_type = 'PRIMARY KEY'
+				AND tc.table_schema = t.table_schema
+				AND tc.table_name = t.table_name
+				AND kcu.column_name = c.column_name
+			) AS is_primary_key
+		FROM information_schema.tables t
+		JOIN information_schema.columns c
+			ON t.table_name = c.table_name
+			AND t.table_schema = c.table_schema
+		WHERE t.table_schema = 'public'
+		AND t.table_type = 'BASE TABLE'
+		ORDER BY t.table_name, c.ordinal_position;
+	`
+
+	rows, err := a.db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var tables []TableInfo
+	var current *TableInfo
+	for rows.Next() {
+		var tableName, columnName, dataType, isNullable string
+		var columnDefault sql.NullString
+		var isPrimaryKey bool
+		if err := rows.Scan(&tableName, &columnName, &dataType, &isNullable, &columnDefault, &isPrimaryKey); err != nil {
+			return nil, err
+		}
+
+		if current == nil || current.Name != tableName {
+			tables = append(tables, TableInfo{Name: tableName})
+			current = &tables[len(tables)-1]
+		}
+
+		current.Columns = append(current.Columns, ColumnInfo{
+			Name:         columnName,
+			Type:         dataType,
+			IsNullable:   isNullable == "YES",
+			DefaultValue: columnDefault,
+			IsPrimaryKey: isPrimaryKey,
+		})
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	// Get row counts in a second pass, after the column cursor is closed.
+	// %q double-quotes the identifier to guard against reserved words.
+	for i := range tables {
+		countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %q", tables[i].Name)
+		if err := a.db.QueryRowContext(ctx, countQuery).Scan(&tables[i].RowCount); err != nil {
+			// Log error but continue with a zero count
+			fmt.Printf("Warning: Failed to get row count for %s: %v\n", tables[i].Name, err)
+		}
+	}
+
+	return tables, nil
+}
+
+func (a *Analyzer) getConstraints(ctx context.Context) ([]ConstraintInfo, error) {
+	// array_to_string flattens the aggregated column names because
+	// database/sql cannot scan a Postgres array into []string directly.
+	query := `
+		SELECT
+			c.conname as name,
+			c.contype as type,
+			t.relname as table_name,
+			array_to_string(array_agg(a.attname), ',') as columns,
+			pg_get_constraintdef(c.oid) as definition
+		FROM pg_constraint c
+		JOIN pg_class t ON c.conrelid = t.oid
+		JOIN pg_namespace n ON t.relnamespace = n.oid
+		JOIN pg_attribute a ON a.attrelid = t.oid
+		WHERE n.nspname = 'public'
+		AND a.attnum = ANY(c.conkey)
+		GROUP BY c.conname, c.contype, t.relname, c.oid
+		ORDER BY t.relname, c.conname;
+	`
+
+	rows, err := a.db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var constraints []ConstraintInfo
+	for rows.Next() {
+		var c ConstraintInfo
+		var columnList string
+		if err := rows.Scan(&c.Name, &c.Type, &c.Table, &columnList, &c.Definition); err != nil {
+			return nil, err
+		}
+		c.Columns = strings.Split(columnList, ",")
+		constraints = append(constraints, c)
+	}
+
+	return constraints, rows.Err()
+}
+
+func (a *Analyzer) getIndexes(ctx context.Context) ([]IndexInfo, error) {
+	query := `
+		SELECT
+			i.relname as name,
+			t.relname as table_name,
+			array_to_string(array_agg(a.attname), ',') as columns,
+			ix.indisunique as is_unique,
+			pg_get_indexdef(i.oid) as definition
+		FROM pg_index ix
+		JOIN pg_class i ON i.oid = ix.indexrelid
+		JOIN pg_class t ON t.oid = ix.indrelid
+		JOIN pg_namespace n ON n.oid = t.relnamespace
+		JOIN pg_attribute a ON a.attrelid = t.oid
+		WHERE n.nspname = 'public'
+		AND a.attnum = ANY(ix.indkey)
+		AND 
t.relkind = 'r'
+		GROUP BY i.relname, t.relname, ix.indisunique, i.oid
+		ORDER BY t.relname, i.relname;
+	`
+
+	rows, err := a.db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	var indexes []IndexInfo
+	for rows.Next() {
+		var idx IndexInfo
+		var columnList string
+		// Scan the flattened column list, then split it; database/sql
+		// cannot scan a Postgres array into []string directly.
+		if err := rows.Scan(&idx.Name, &idx.Table, &columnList, &idx.IsUnique, &idx.Definition); err != nil {
+			return nil, err
+		}
+		idx.Columns = strings.Split(columnList, ",")
+		indexes = append(indexes, idx)
+	}
+
+	return indexes, rows.Err()
+}
+
+// AnalyzeNormalization checks the schema for normalization issues
+func (a *Analyzer) AnalyzeNormalization(ctx context.Context) ([]NormalizationIssue, error) {
+	var issues []NormalizationIssue
+
+	// Get schema info
+	schema, err := a.GetSchemaInfo(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// Check each table
+	for _, table := range schema.Tables {
+		// Check for 1NF violations (non-atomic values)
+		for _, col := range table.Columns {
+			if strings.Contains(col.Type, "ARRAY") || strings.Contains(col.Type, "JSON") {
+				issues = append(issues, NormalizationIssue{
+					Level:      "1NF",
+					Table:      table.Name,
+					Column:     col.Name,
+					Issue:      "Non-atomic values",
+					Suggestion: fmt.Sprintf("Consider normalizing %s into a separate table", col.Name),
+				})
+			}
+		}
+
+		// Check for potential 2NF violations (partial dependencies)
+		pkColumns := a.getPrimaryKeyColumns(table)
+		if len(pkColumns) > 1 {
+			for _, col := range table.Columns {
+				if !contains(pkColumns, col.Name) {
+					issues = append(issues, NormalizationIssue{
+						Level:      "2NF",
+						Table:      table.Name,
+						Column:     col.Name,
+						Issue:      "Potential partial dependency",
+						Suggestion: "Verify if this column depends on the full primary key",
+					})
+				}
+			}
+		}
+
+		// Check for potential 3NF violations (transitive dependencies)
+		for _, col1 := range table.Columns {
+			for _, col2 := range table.Columns {
+				if col1.Name != col2.Name && !col1.IsPrimaryKey && !col2.IsPrimaryKey {
+					if a.mightHaveTransitiveDependency(ctx, table.Name, col1.Name, col2.Name) {
+						issues = append(issues, NormalizationIssue{
+							Level:      "3NF",
+							Table:      table.Name,
+							Column:     fmt.Sprintf("%s -> %s", col1.Name, col2.Name),
+							Issue:      "Potential transitive dependency",
+							Suggestion: "Consider moving these columns to a separate table",
+						})
+					}
+				}
+			}
+		}
+	}
+
+	return issues, nil
+}
+
+// NormalizationIssue represents a potential normalization problem
+type NormalizationIssue struct {
+	Level      string // 1NF, 2NF, 3NF
+	Table      string
+	Column     string
+	Issue      string
+	Suggestion string
+}
+
+func (a *Analyzer) getPrimaryKeyColumns(table TableInfo) []string {
+	var pkColumns []string
+	for _, col := range table.Columns {
+		if col.IsPrimaryKey {
+			pkColumns = append(pkColumns, col.Name)
+		}
+	}
+	return pkColumns
+}
+
+func (a *Analyzer) mightHaveTransitiveDependency(ctx context.Context, table, col1, col2 string) bool {
+	// Simplified heuristic: if the number of distinct col1 values equals
+	// the number of distinct (col1, col2) pairs, col1 functionally
+	// determines col2, hinting at a transitive dependency via the key.
+	// Identifiers are quoted with %q to guard against reserved words.
+	query := fmt.Sprintf(`
+		SELECT COUNT(DISTINCT %q) = COUNT(DISTINCT (%q, %q))
+		FROM %q
+		HAVING COUNT(*) > 0;
+	`, col1, col1, col2, table)
+
+	var hasTransitive bool
+	err := a.db.QueryRowContext(ctx, query).Scan(&hasTransitive)
+	if err != nil {
+		return false
+	}
+	return hasTransitive
+}
+
+func contains(slice []string, str string) bool {
+	for _, s := range slice {
+		if s == str {
+			return true
+		}
+	}
+	return false
+}
diff --git a/internal/jsonschema/schema.go b/internal/jsonschema/schema.go
new file mode 100644
index 000000000..f02bcc9c2
--- /dev/null
+++ b/internal/jsonschema/schema.go
@@ -0,0 +1,14 @@
+package jsonschema
+
+const (
+	String = "string"
+	Object = "object"
+)
+
+type Definition struct {
+	Type        string                `json:"type"`
+	Properties  map[string]Definition `json:"properties,omitempty"`
+	Required    []string              `json:"required,omitempty"`
+	Enum        []string              `json:"enum,omitempty"`
+	Description string                `json:"description,omitempty"`
+}
diff --git a/internal/migration/new/new.go b/internal/migration/new/new.go
index be232b591..fd98c4d47 100644
--- a/internal/migration/new/new.go
+++ b/internal/migration/new/new.go
@@ -5,29 +5,34 @@ import (
 	"io"
 	"os"
 	"path/filepath"
+	"time"
 
 	"github.com/go-errors/errors"
 	"github.com/spf13/afero"
 	"github.com/supabase/cli/internal/utils"
 )
 
-func Run(migrationName string, stdin afero.File, fsys afero.Fs) error {
-	path := GetMigrationPath(utils.GetCurrentTimestamp(), migrationName)
-	if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(path)); err != nil {
-		return err
-	}
-	f, err := fsys.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
-	if err != nil {
-		return errors.Errorf("failed to open migration file: %w", err)
-	}
-	defer func() {
-		fmt.Println("Created new migration at " + utils.Bold(path))
-		// File descriptor will always be closed when process quits
-		_ = f.Close()
-	}()
-	return CopyStdinIfExists(stdin, f)
+func Run(name string, reader io.Reader, fs afero.Fs) error {
+	timestamp := time.Now().UTC().Format("20060102150405")
+	filename := fmt.Sprintf("%s_%s.sql", timestamp, name)
+	path := filepath.Join("supabase", "migrations", filename)
+
+	content, err := io.ReadAll(reader)
+	if err != nil {
+		return err
+	}
+
+	// afero.WriteFile does not create parent directories, so ensure the
+	// migrations directory exists before writing.
+	if err := fs.MkdirAll(filepath.Dir(path), 0755); err != nil {
+		return errors.Errorf("failed to create migrations directory: %w", err)
+	}
+	if err := afero.WriteFile(fs, path, content, 0644); err != nil {
+		return errors.Errorf("failed to create migration: %w", err)
+	}
+
+	fmt.Println("Created new migration at " + utils.Bold(path))
+	return nil
 }
 
 func GetMigrationPath(timestamp, name string) string {
diff --git a/scripts/test-dna.ps1 b/scripts/test-dna.ps1
new file mode 100644
index 000000000..dd6b31520
--- /dev/null
+++ b/scripts/test-dna.ps1
@@ -0,0 +1,3 @@
+# Run DNA assistant tests
+Write-Host "Running DNA assistant tests..."
+go test -v ./cmd -run "TestDNAAssistant" -count=1
\ No newline at end of file
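diff --git a/cmd/assistant_test.go b/cmd/assistant_test.go
new file mode 100644
--- /dev/null
+++ b/cmd/assistant_test.go
@@ -0,0 +1,17 @@
+package cmd
+
+import "testing"
+
+// TestDNAAssistant is a minimal smoke test backing scripts/test-dna.ps1,
+// which runs `go test ./cmd -run "TestDNAAssistant"` but would otherwise
+// match no test. It is a sketch only: it checks command wiring and
+// deliberately avoids the network and any API key, so chat behaviour is
+// not exercised here.
+func TestDNAAssistant(t *testing.T) {
+	if assistantCmd.Use != "assistant" {
+		t.Fatalf("unexpected command use: %q", assistantCmd.Use)
+	}
+	if !assistantCmd.HasSubCommands() {
+		t.Fatal("expected assistant command to have subcommands")
+	}
+}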