fix(vertex): audit fixes for Vertex Service Account feature (#1977)

- Security: force token_uri to Google default, preventing SSRF via crafted service account JSON
- Dedup: extract shared getVertexServiceAccountAccessToken() to eliminate ~35 lines of duplication between ClaudeTokenProvider and GeminiTokenProvider
- Fix: apply model mapping + Vertex model ID normalization in forward_as_responses and forward_as_chat_completions paths
- Fix: exclude service_account from AI Studio endpoint selection (Vertex cannot serve generativelanguage.googleapis.com)
- Feature: add model restriction/mapping UI for service_account in EditAccountModal
- Dedup: extract VERTEX_LOCATION_OPTIONS to shared constants
- i18n: replace all hardcoded Chinese strings in Vertex UI with translation keys
This commit is contained in:
shaw
2026-04-29 16:53:09 +08:00
parent 63ef23108c
commit 93d91e20b9
11 changed files with 378 additions and 191 deletions

View File

@@ -162,40 +162,5 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
key, err := parseVertexServiceAccountKey(account)
if err != nil {
return "", err
}
cacheKey := vertexServiceAccountCacheKey(account, key)
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
locked := false
if p.tokenCache != nil {
var lockErr error
locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
}
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
if err != nil {
return "", err
}
if p.tokenCache != nil {
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
}

View File

@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
if normalized != originalModel {
mappedModel = normalized
}
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized

View File

@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
if normalized != originalModel {
mappedModel = normalized
}
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized

View File

@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
case AccountTypeServiceAccount:
// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
return 999
default:
return 10
}

View File

@@ -172,42 +172,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
key, err := parseVertexServiceAccountKey(account)
if err != nil {
return "", err
}
cacheKey := vertexServiceAccountCacheKey(account, key)
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
locked := false
if p.tokenCache != nil {
var lockErr error
locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(200 * time.Millisecond)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
}
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
if err != nil {
return "", err
}
if p.tokenCache != nil {
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
}
func GeminiTokenCacheKey(account *Account) string {

View File

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"regexp"
@@ -23,6 +24,7 @@ const (
vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
vertexServiceAccountCacheSkew = 5 * time.Minute
vertexLockWaitTime = 200 * time.Millisecond
vertexAnthropicVersion = "vertex-2023-10-16"
)
@@ -123,9 +125,8 @@ func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error)
if strings.TrimSpace(key.ProjectID) == "" {
return nil, errors.New("service account json missing project_id")
}
if strings.TrimSpace(key.TokenURI) == "" {
key.TokenURI = vertexDefaultTokenURL
}
// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
key.TokenURI = vertexDefaultTokenURL
return &key, nil
}
@@ -141,6 +142,47 @@ func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey
return "vertex:service_account:" + fingerprint
}
// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
// using the shared cache and distributed lock to avoid redundant exchanges.
func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
key, err := parseVertexServiceAccountKey(account)
if err != nil {
return "", err
}
cacheKey := vertexServiceAccountCacheKey(account, key)
if cache != nil {
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
locked := false
if cache != nil {
var lockErr error
locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(vertexLockWaitTime)
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
}
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
if err != nil {
return "", err
}
if cache != nil {
_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
now := time.Now()
claims := jwt.MapClaims{