fix(vertex): audit fixes for Vertex Service Account feature (#1977)
- Security: force token_uri to Google default, preventing SSRF via crafted service account JSON - Dedup: extract shared getVertexServiceAccountAccessToken() to eliminate ~35 lines of duplication between ClaudeTokenProvider and GeminiTokenProvider - Fix: apply model mapping + Vertex model ID normalization in forward_as_responses and forward_as_chat_completions paths - Fix: exclude service_account from AI Studio endpoint selection (Vertex cannot serve generativelanguage.googleapis.com) - Feature: add model restriction/mapping UI for service_account in EditAccountModal - Dedup: extract VERTEX_LOCATION_OPTIONS to shared constants - i18n: replace all hardcoded Chinese strings in Vertex UI with translation keys
This commit is contained in:
@@ -162,40 +162,5 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
key, err := parseVertexServiceAccountKey(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cacheKey := vertexServiceAccountCacheKey(account, key)
|
||||
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
locked := false
|
||||
if p.tokenCache != nil {
|
||||
var lockErr error
|
||||
locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if p.tokenCache != nil {
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
return accessToken, nil
|
||||
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
|
||||
}
|
||||
|
||||
@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
|
||||
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
|
||||
@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
|
||||
// 4. Model mapping
|
||||
mappedModel := originalModel
|
||||
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
|
||||
normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
}
|
||||
} else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(originalModel)
|
||||
if normalized != originalModel {
|
||||
mappedModel = normalized
|
||||
|
||||
@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
}
|
||||
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
|
||||
return 3
|
||||
case AccountTypeServiceAccount:
|
||||
// Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
|
||||
// endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
|
||||
return 999
|
||||
default:
|
||||
return 10
|
||||
}
|
||||
|
||||
@@ -172,42 +172,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
key, err := parseVertexServiceAccountKey(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cacheKey := vertexServiceAccountCacheKey(account, key)
|
||||
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
locked := false
|
||||
if p.tokenCache != nil {
|
||||
var lockErr error
|
||||
locked, lockErr = p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if p.tokenCache != nil {
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
return accessToken, nil
|
||||
return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
|
||||
}
|
||||
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -23,6 +24,7 @@ const (
|
||||
vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
|
||||
vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
||||
vertexServiceAccountCacheSkew = 5 * time.Minute
|
||||
vertexLockWaitTime = 200 * time.Millisecond
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
@@ -123,9 +125,8 @@ func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error)
|
||||
if strings.TrimSpace(key.ProjectID) == "" {
|
||||
return nil, errors.New("service account json missing project_id")
|
||||
}
|
||||
if strings.TrimSpace(key.TokenURI) == "" {
|
||||
key.TokenURI = vertexDefaultTokenURL
|
||||
}
|
||||
// Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
|
||||
key.TokenURI = vertexDefaultTokenURL
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
@@ -141,6 +142,47 @@ func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey
|
||||
return "vertex:service_account:" + fingerprint
|
||||
}
|
||||
|
||||
// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
|
||||
// using the shared cache and distributed lock to avoid redundant exchanges.
|
||||
func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
|
||||
key, err := parseVertexServiceAccountKey(account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cacheKey := vertexServiceAccountCacheKey(account, key)
|
||||
|
||||
if cache != nil {
|
||||
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
locked := false
|
||||
if cache != nil {
|
||||
var lockErr error
|
||||
locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(vertexLockWaitTime)
|
||||
if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if cache != nil {
|
||||
_ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
|
||||
Reference in New Issue
Block a user