feat: 添加 Anthropic 缓存 TTL 注入开关

This commit is contained in:
shaw
2026-04-30 13:38:22 +08:00
parent 094e1171ef
commit 73b872998e
12 changed files with 394 additions and 54 deletions

View File

@@ -62,6 +62,11 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
const (
cacheTTLTarget5m = "5m"
cacheTTLTarget1h = "1h"
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
@@ -4226,6 +4231,87 @@ func enforceCacheControlLimit(body []byte) []byte {
return body
}
// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。
// 仅修改已经存在的 cache_control不新增缓存断点。
func injectAnthropicCacheControlTTL1h(body []byte) []byte {
return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
}
func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
if len(body) == 0 || ttl == "" {
return body
}
out := body
var paths []string
addPath := func(path string, value gjson.Result) {
cc := value.Get("cache_control")
if !cc.Exists() || cc.Get("type").String() != "ephemeral" {
return
}
if cc.Get("ttl").String() == ttl {
return
}
paths = append(paths, path+".cache_control.ttl")
}
if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl {
paths = append(paths, "cache_control.ttl")
}
system := gjson.GetBytes(body, "system")
if system.IsArray() {
idx := -1
system.ForEach(func(_, block gjson.Result) bool {
idx++
addPath(fmt.Sprintf("system.%d", idx), block)
return true
})
}
messages := gjson.GetBytes(body, "messages")
if messages.IsArray() {
msgIdx := -1
messages.ForEach(func(_, msg gjson.Result) bool {
msgIdx++
content := msg.Get("content")
if !content.IsArray() {
return true
}
contentIdx := -1
content.ForEach(func(_, block gjson.Result) bool {
contentIdx++
addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block)
return true
})
return true
})
}
tools := gjson.GetBytes(body, "tools")
if tools.IsArray() {
idx := -1
tools.ForEach(func(_, tool gjson.Result) bool {
idx++
addPath(fmt.Sprintf("tools.%d", idx), tool)
return true
})
}
for _, path := range paths {
if next, err := sjson.SetBytes(out, path, ttl); err == nil {
out = next
}
}
return out
}
func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil {
return false
}
return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
@@ -4385,6 +4471,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
body = injectAnthropicCacheControlTTL1h(body)
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -7225,9 +7315,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
@@ -7634,6 +7724,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
return true
}
func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
if account == nil {
return "", false
}
if account.IsCacheTTLOverrideEnabled() {
return account.GetCacheTTLOverrideTarget(), true
}
if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
return cacheTTLTarget5m, true
}
return "", false
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -7670,9 +7773,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget()
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
@@ -8240,10 +8343,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
// 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
applyCacheTTLOverride(&result.Usage, overrideTarget)
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}