|
| 1 | +package bedrock |
| 2 | + |
import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"
)
| 9 | + |
// jsonContentType is the Content-Type header value set on every response.
const jsonContentType = "application/json"
| 11 | + |
// foundationModel is the subset of the Bedrock foundation-model summary
// shape this mock serves: identity, provider, and supported I/O modalities.
// Field tags match the casing Bedrock uses in its JSON responses.
type foundationModel struct {
	ModelID          string   `json:"modelId"`
	ModelName        string   `json:"modelName"`
	ProviderName     string   `json:"providerName"`
	InputModalities  []string `json:"inputModalities"`
	OutputModalities []string `json:"outputModalities"`
}
| 19 | + |
// mockModels is the fixed catalog the mock serves: text-in/text-out models
// from Anthropic, Amazon, and Meta. The routing in invokeModel keys off the
// provider prefix of each ModelID.
var mockModels = []foundationModel{
	{ModelID: "anthropic.claude-3-sonnet-20240229-v1:0", ModelName: "Claude 3 Sonnet", ProviderName: "Anthropic", InputModalities: []string{"TEXT"}, OutputModalities: []string{"TEXT"}},
	{ModelID: "anthropic.claude-instant-v1", ModelName: "Claude Instant", ProviderName: "Anthropic", InputModalities: []string{"TEXT"}, OutputModalities: []string{"TEXT"}},
	{ModelID: "amazon.titan-text-express-v1", ModelName: "Titan Text Express", ProviderName: "Amazon", InputModalities: []string{"TEXT"}, OutputModalities: []string{"TEXT"}},
	{ModelID: "meta.llama2-13b-chat-v1", ModelName: "Llama 2 13B Chat", ProviderName: "Meta", InputModalities: []string{"TEXT"}, OutputModalities: []string{"TEXT"}},
}
| 26 | + |
| 27 | +func Start(port int) error { |
| 28 | + return http.ListenAndServe(fmt.Sprintf(":%d", port), newServer()) |
| 29 | +} |
| 30 | + |
| 31 | +func newServer() http.Handler { |
| 32 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 33 | + switch { |
| 34 | + case r.Method == http.MethodPost && r.URL.Path == "/foundation-models": |
| 35 | + listFoundationModels(w) |
| 36 | + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/foundation-models/"): |
| 37 | + getFoundationModel(w, r) |
| 38 | + case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/model/"): |
| 39 | + invokeModelRoute(w, r) |
| 40 | + default: |
| 41 | + writeError(w, http.StatusNotFound, "ResourceNotFoundException", "Not found") |
| 42 | + } |
| 43 | + }) |
| 44 | +} |
| 45 | + |
| 46 | +func listFoundationModels(w http.ResponseWriter) { |
| 47 | + writeJSON(w, http.StatusOK, map[string]any{"modelSummaries": mockModels}) |
| 48 | +} |
| 49 | + |
| 50 | +func getFoundationModel(w http.ResponseWriter, r *http.Request) { |
| 51 | + modelID := strings.TrimPrefix(r.URL.Path, "/foundation-models/") |
| 52 | + for _, model := range mockModels { |
| 53 | + if model.ModelID == modelID { |
| 54 | + writeJSON(w, http.StatusOK, model) |
| 55 | + return |
| 56 | + } |
| 57 | + } |
| 58 | + writeError(w, http.StatusNotFound, "ResourceNotFoundException", "Model not found") |
| 59 | +} |
| 60 | + |
| 61 | +func invokeModelRoute(w http.ResponseWriter, r *http.Request) { |
| 62 | + path := strings.TrimPrefix(r.URL.Path, "/model/") |
| 63 | + parts := strings.Split(path, "/") |
| 64 | + if len(parts) != 2 { |
| 65 | + writeError(w, http.StatusNotFound, "ResourceNotFoundException", "Not found") |
| 66 | + return |
| 67 | + } |
| 68 | + modelID, action := parts[0], parts[1] |
| 69 | + if action != "invoke" && action != "invoke-with-response-stream" { |
| 70 | + writeError(w, http.StatusNotFound, "ResourceNotFoundException", "Not found") |
| 71 | + return |
| 72 | + } |
| 73 | + invokeModel(w, r, modelID) |
| 74 | +} |
| 75 | + |
| 76 | +func invokeModel(w http.ResponseWriter, r *http.Request, modelID string) { |
| 77 | + var body map[string]any |
| 78 | + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { |
| 79 | + writeError(w, http.StatusBadRequest, "ValidationException", "Invalid JSON body") |
| 80 | + return |
| 81 | + } |
| 82 | + |
| 83 | + switch { |
| 84 | + case strings.HasPrefix(modelID, "anthropic."): |
| 85 | + writeJSON(w, http.StatusOK, map[string]any{ |
| 86 | + "content": []map[string]any{{"type": "text", "text": "Mock response from Claude: "}}, |
| 87 | + "stop_reason": "end_turn", |
| 88 | + "usage": map[string]any{"input_tokens": 10, "output_tokens": 20}, |
| 89 | + }) |
| 90 | + case strings.HasPrefix(modelID, "amazon.titan-"): |
| 91 | + writeJSON(w, http.StatusOK, map[string]any{ |
| 92 | + "results": []map[string]any{{"outputText": "Mock response from Titan: ", "tokenCount": 20, "completionReason": "FINISH"}}, |
| 93 | + }) |
| 94 | + case strings.HasPrefix(modelID, "meta.llama"): |
| 95 | + writeJSON(w, http.StatusOK, map[string]any{ |
| 96 | + "generation": "Mock response from Llama: ", |
| 97 | + "prompt_token_count": 10, |
| 98 | + "generation_token_count": 20, |
| 99 | + "stop_reason": "stop", |
| 100 | + }) |
| 101 | + default: |
| 102 | + writeError(w, http.StatusNotFound, "ResourceNotFoundException", "Model not found") |
| 103 | + } |
| 104 | +} |
| 105 | + |
| 106 | +func writeJSON(w http.ResponseWriter, status int, payload any) { |
| 107 | + w.Header().Set("Content-Type", jsonContentType) |
| 108 | + w.WriteHeader(status) |
| 109 | + _ = json.NewEncoder(w).Encode(payload) |
| 110 | +} |
| 111 | + |
| 112 | +func writeError(w http.ResponseWriter, status int, code, message string) { |
| 113 | + writeJSON(w, status, map[string]any{"__type": code, "message": message}) |
| 114 | +} |