diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index d7e15377..f8b22ee7 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -28,6 +28,26 @@ jobs:
working-directory: backend
run: make test-integration
+ frontend:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ - name: Setup pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ version: 9
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: '20'
+ cache: 'pnpm'
+ cache-dependency-path: frontend/pnpm-lock.yaml
+ - name: Install frontend dependencies
+ working-directory: frontend
+ run: pnpm install --frozen-lockfile
+ - name: Frontend typecheck and critical vitest
+ run: make test-frontend
+
golangci-lint:
runs-on: ubuntu-latest
steps:
@@ -46,4 +66,4 @@ jobs:
with:
version: v2.9
args: --timeout=30m
- working-directory: backend
\ No newline at end of file
+ working-directory: backend
diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml
new file mode 100644
index 00000000..67c8d6e9
--- /dev/null
+++ b/.github/workflows/cla.yml
@@ -0,0 +1,59 @@
+name: "CLA Assistant"
+
+on:
+ issue_comment:
+ types: [created]
+ pull_request_target:
+ types: [opened, reopened, closed, synchronize]
+
+permissions:
+ actions: write
+ contents: write
+ pull-requests: write
+ statuses: write
+
+jobs:
+ cla-check:
+ if: |
+ github.event_name == 'issue_comment' ||
+ (github.event_name == 'pull_request_target' && github.event.action != 'closed')
+ runs-on: ubuntu-latest
+ steps:
+ - name: "CLA Assistant"
+ if: |
+ (github.event.comment.body == 'recheck' ||
+ github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') ||
+ github.event_name == 'pull_request_target'
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ allowlist: "dependabot[bot],renovate[bot],bot*"
+ lock-pullrequest-aftermerge: false
+ custom-notsigned-prcomment: |
+ Thank you for your contribution! Before we can merge this PR, we need $you to sign our [Contributor License Agreement (CLA)](https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md).
+
+ **To sign**, please reply with the following comment:
+
+ > I have read the CLA Document and I hereby sign the CLA
+
+ You only need to sign once — it will be valid for all your future contributions to this project.
+ custom-pr-sign-comment: "I have read the CLA Document and I hereby sign the CLA"
+ custom-allsigned-prcomment: "All contributors have signed the CLA. ✅"
+
+ cla-lock:
+ if: github.event_name == 'pull_request_target' && github.event.action == 'closed' && github.event.pull_request.merged == true
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Lock merged PR"
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ lock-pullrequest-aftermerge: true
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index b729c575..26ed8524 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -246,10 +246,10 @@ jobs:
if [ -n "$DOCKERHUB_USERNAME" ]; then
DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api"
MESSAGE+="# Docker Hub"$'\n'
- MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
+ MESSAGE+="docker pull ${DOCKER_IMAGE}:${VERSION}"$'\n'
MESSAGE+="# GitHub Container Registry"$'\n'
fi
- MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n'
+ MESSAGE+="docker pull ${GHCR_IMAGE}:${VERSION}"$'\n'
MESSAGE+="\`\`\`"$'\n'$'\n'
MESSAGE+="🔗 *相关链接:*"$'\n'
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
diff --git a/.gitignore b/.gitignore
index b07cd286..ef1d0875 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
docs/claude-relay-service/
+.codex
# ===================
# Go 后端
@@ -121,7 +122,7 @@ scripts
.code-review-state
#openspec/
code-reviews/
-#AGENTS.md
+AGENTS.md
backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
@@ -131,10 +132,10 @@ docs/*
!docs/PAYMENT_CN.md
!docs/superpowers/
!docs/design-drafts/
+!docs/ADMIN_PAYMENT_INTEGRATION_API.md
.superpowers/
.serena/
.codex/
frontend/coverage/
aicodex
output/
-
diff --git a/CLA.md b/CLA.md
new file mode 100644
index 00000000..ed0d74b8
--- /dev/null
+++ b/CLA.md
@@ -0,0 +1,73 @@
+# Sub2API Individual Contributor License Agreement (v1.0)
+
+Thank you for your interest in contributing to Sub2API ("the Project"). This Contributor License Agreement ("Agreement") documents the rights granted by contributors to the Project.
+
+By signing this Agreement, you accept and agree to the following terms and conditions for your present and future contributions submitted to the Project.
+
+## 1. Definitions
+
+- **"You" (or "Your")** means the copyright owner or legal entity authorized by the copyright owner that is making this Agreement.
+- **"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the Project for inclusion in, or documentation of, any of the products owned or managed by the Project. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Project or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Project for the purpose of discussing and improving the Project, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
+- **"Project Owner"** means Wesley Liddick, or any individual or legal entity to whom Wesley Liddick has explicitly assigned or transferred ownership of the Project in writing, and their respective successors and assigns.
+
+## 2. Grant of Copyright License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. This license includes, without limitation, the right to sublicense, assign, and transfer these rights to any third party, including without limitation any successor, assignee, or acquiring entity of the Project or the Project Owner, and to use Your Contributions under any license, including proprietary or commercial licenses.
+
+## 3. Moral Rights
+
+To the fullest extent permitted by applicable law, You irrevocably waive and agree not to assert any moral rights (including rights of attribution and integrity) that You may have in Your Contributions, and agree that the Project Owner and its licensees may use, modify, and distribute Your Contributions without attribution or other obligations arising from moral rights.
+
+## 4. Grant of Patent License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Your Contributions, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Project to which such Contribution(s) was submitted.
+
+## 5. Representations and Warranties
+
+You represent and warrant that:
+
+(a) You are legally entitled to grant the above licenses.
+
+(b) If Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You have received permission to make Contributions on behalf of that employer, or that Your employer has waived such rights for Your Contributions to the Project.
+
+(c) Each of Your Contributions is Your original creation, or You have sufficient rights to submit it under the terms of this Agreement. You agree to provide, upon request, reasonable documentation or explanation of any third-party materials included in Your Contributions.
+
+## 6. No Warranty
+
+Your Contributions are provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
+
+## 7. No Obligation
+
+You understand that the decision to include Your Contribution in any product or project is entirely at the discretion of the Project Owner, and this Agreement does not obligate the Project Owner to use Your Contribution.
+
+## 8. Retention of Rights
+
+You retain ownership of the copyright in Your Contributions. This Agreement does not transfer any copyright or other intellectual property rights from You to the Project Owner. This Agreement only grants the licenses described above.
+
+## 9. Term and Termination
+
+This Agreement shall remain in effect indefinitely. You may terminate this Agreement prospectively by providing written notice to the Project Owner, but such termination shall not affect the licenses granted for Contributions submitted prior to the effective date of termination. The licenses granted herein for Contributions submitted prior to termination are perpetual and irrevocable.
+
+## 10. Electronic Signature
+
+You agree that Your electronic signature (including but not limited to typing a specific phrase in a pull request, issue, or other electronic communication) is legally binding and has the same force and effect as a handwritten signature. You consent to the use of electronic means to enter into this Agreement and acknowledge that this Agreement is enforceable as if executed in a traditional written format.
+
+## 11. General Provisions
+
+**Entire Agreement.** This Agreement constitutes the entire agreement between You and the Project Owner with respect to Your Contributions and supersedes all prior or contemporaneous understandings regarding such subject matter.
+
+**Severability.** If any provision of this Agreement is held to be unenforceable or invalid, that provision will be enforced to the maximum extent possible and the remaining provisions will remain in full force and effect.
+
+**No Waiver.** The failure of the Project Owner to enforce any provision of this Agreement shall not constitute a waiver of that provision or any other provision.
+
+**Amendment.** This Agreement may only be modified by a written instrument signed by both parties. Modifications to this Agreement apply only to Contributions submitted after the modified Agreement is published and accepted by You. Prior Contributions remain governed by the version of the Agreement in effect at the time of submission.
+
+**Notification.** Notices under this Agreement shall be sent to the Project Owner via a GitHub issue on the Project repository. Notices are effective upon receipt.
+
+---
+
+**By signing this CLA, you acknowledge that you have read and understood this Agreement and agree to be bound by its terms.**
+
+To sign, reply in the pull request with:
+
+> I have read the CLA Document and I hereby sign the CLA
diff --git a/LICENSE b/LICENSE
index 7a94ca9d..153d416d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,165 @@
-MIT License
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
-Copyright (c) 2025 Wesley Liddick
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
\ No newline at end of file
diff --git a/Makefile b/Makefile
index fd6a5a9a..d00d0c4f 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,12 @@
-.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
+.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan
+
+FRONTEND_CRITICAL_VITEST := \
+ src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \
+ src/views/auth/__tests__/WechatCallbackView.spec.ts \
+ src/views/user/__tests__/PaymentView.spec.ts \
+ src/views/user/__tests__/PaymentResultView.spec.ts \
+ src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
+ src/views/admin/__tests__/SettingsView.spec.ts
# 一键编译前后端
build: build-backend build-frontend
@@ -24,6 +32,10 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
+ @$(MAKE) test-frontend-critical
+
+test-frontend-critical:
+ @pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST)
test-datamanagementd:
@cd datamanagement && go test ./...
diff --git a/README.md b/README.md
index 74ab9af2..718730c6 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups , users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)
+
+
+Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.
+
+
+
+
+Thanks to PatewayAI for sponsoring this project! PatewayAI is a premium model API relay service provider built for heavy AI developers, focused on direct official connections. Offering the full Claude series and Codex series models, 100% sourced directly from official providers — no dilution, no substitution, open to verification. Billing is fully transparent with token-level invoices that can be audited line by line.
+Enterprise-grade high concurrency is also supported, with a dedicated management platform for enterprise clients. Enterprise customers can sign formal contracts and receive invoices. Visit the official website for more details and contact information.
+Register now via this link to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.
+
+
## Ecosystem
@@ -618,7 +630,9 @@ sub2api/
## License
-MIT License
+This project is licensed under the [GNU Lesser General Public License v3.0](LICENSE) (or later).
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_CN.md b/README_CN.md
index c701372c..24600e0e 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -95,6 +95,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
+
+
+感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。
+
+
+
+
+感谢 PatewayAI 赞助了本项目!PatewayAI 是一家面向重度 AI 开发者、专注官方直连的高品质模型 API 中转服务商。提供 Claude 全系列与 Codex 系列模型,100% 官方源直供,不掺假不注水,欢迎检验。计费透明,Token 级账单可逐笔核验。
+同时支持企业级高并发,并为企业客户提供了专业的管理平台,企业客户可签订正式合同并开具发票,更多详情进入官网获取联系方式。
+现在通过 此链接 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。
+
+
## 生态项目
@@ -679,7 +691,9 @@ sub2api/
## 许可证
-MIT License
+本项目基于 [GNU 宽通用公共许可证 v3.0](LICENSE)(或更高版本)授权。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_JA.md b/README_JA.md
index 0d4db616..1e89610c 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -95,6 +95,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
+
+
+Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。
+
+
+
+
+PatewayAI のご支援に感謝します!PatewayAI は、ヘビーAI開発者向けに公式直結を重視した高品質モデルAPIリレーサービスプロバイダーです。Claude 全シリーズおよび Codex シリーズモデルを提供し、100%公式ソースから直接供給 — 偽りなし、水増しなし、検証歓迎。課金は完全透明で、トークン単位の請求書を1件ずつ監査可能です。
+エンタープライズ級の高同時接続にも対応し、法人顧客向けに専用管理プラットフォームを提供しています。法人顧客は正式な契約を締結し、請求書の発行が可能です。詳細は公式サイトでお問い合わせください。
+こちらのリンク から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。
+
+
## エコシステム
@@ -617,7 +629,9 @@ sub2api/
## ライセンス
-MIT License
+本プロジェクトは [GNU Lesser General Public License v3.0](LICENSE)(またはそれ以降のバージョン)の下でライセンスされています。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/assets/partners/logos/bestproxy.png b/assets/partners/logos/bestproxy.png
new file mode 100644
index 00000000..87c58670
Binary files /dev/null and b/assets/partners/logos/bestproxy.png differ
diff --git a/assets/partners/logos/pateway.png b/assets/partners/logos/pateway.png
new file mode 100644
index 00000000..7ca3489a
Binary files /dev/null and b/assets/partners/logos/pateway.png differ
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go
index 7eabde62..9386678d 100644
--- a/backend/cmd/jwtgen/main.go
+++ b/backend/cmd/jwtgen/main.go
@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
- authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index c29f5f75..025c3166 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.114
+0.1.121
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 64709b5b..9bfa2717 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -97,6 +97,7 @@ func provideCleanup(
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
paymentOrderExpiry *service.PaymentOrderExpiryService,
+ channelMonitorRunner *service.ChannelMonitorRunner,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -239,6 +240,12 @@ func provideCleanup(
}
return nil
}},
+ {"ChannelMonitorRunner", func() error {
+ if channelMonitorRunner != nil {
+ channelMonitorRunner.Stop()
+ }
+ return nil
+ }},
}
infraSteps := []cleanupStep{
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 1d39fa1e..40f0191c 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -61,14 +61,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCache := repository.NewBillingCache(redisClient)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
- billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
+ userRPMCache := repository.NewUserRPMCache(redisClient)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
+ billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
- apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
+ apiKeyService := service.ProvideAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig, billingCacheService)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
- authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
+ affiliateRepository := repository.NewAffiliateRepository(client, db)
+ affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
+ authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
@@ -79,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService, emailService, emailCache)
+ userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
@@ -90,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
+ channelMonitorRepository := repository.NewChannelMonitorRepository(client, db)
+ channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
+ channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -104,7 +110,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -124,9 +130,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
+ openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
- rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
+ rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
@@ -136,15 +143,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
- oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
+ oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+ claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
- accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
+ accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, claudeTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
@@ -171,18 +179,26 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
- channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
+ channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
- openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
+ openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
+ encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
+ registry := payment.ProvideRegistry()
+ defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
+ paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -210,18 +226,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
- registry := payment.ProvideRegistry()
- encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
- if err != nil {
- return nil, err
- }
- defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
- paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
- paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
- paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
+ channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
+ channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
+ channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
+ affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@@ -231,9 +242,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpHandler := handler.NewTotpHandler(totpService)
handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
+ availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -242,13 +254,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
- opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
+ opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
+ paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -302,6 +316,7 @@ func provideCleanup(
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
paymentOrderExpiry *service.PaymentOrderExpiryService,
+ channelMonitorRunner *service.ChannelMonitorRunner,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -443,6 +458,12 @@ func provideCleanup(
}
return nil
}},
+ {"ChannelMonitorRunner", func() error {
+ if channelMonitorRunner != nil {
+ channelMonitorRunner.Stop()
+ }
+ return nil
+ }},
}
infraSteps := []cleanupStep{
diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go
index a6e0551a..5ccd67fb 100644
--- a/backend/cmd/server/wire_gen_test.go
+++ b/backend/cmd/server/wire_gen_test.go
@@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
pricingSvc := service.NewPricingService(cfg, nil)
emailQueueSvc := service.NewEmailQueueService(nil, 1)
- billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
+ billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
@@ -76,6 +76,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
nil, // scheduledTestRunner
nil, // backupSvc
nil, // paymentOrderExpiry
+ nil, // channelMonitorRunner
)
require.NotPanics(t, func() {
diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go
new file mode 100644
index 00000000..5ccfcf19
--- /dev/null
+++ b/backend/ent/authidentity.go
@@ -0,0 +1,266 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentity is the model entity for the AuthIdentity schema.
+type AuthIdentity struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // VerifiedAt holds the value of the "verified_at" field.
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ // Issuer holds the value of the "issuer" field.
+ Issuer *string `json:"issuer,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityQuery when eager-loading is set.
+ Edges AuthIdentityEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityEdges struct {
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // Channels holds the value of the channels edge.
+ Channels []*AuthIdentityChannel `json:"channels,omitempty"`
+ // AdoptionDecisions holds the value of the adoption_decisions edge.
+ AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// ChannelsOrErr returns the Channels value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) {
+ if e.loadedTypes[1] {
+ return e.Channels, nil
+ }
+ return nil, &NotLoadedError{edge: "channels"}
+}
+
+// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) {
+ if e.loadedTypes[2] {
+ return e.AdoptionDecisions, nil
+ }
+ return nil, &NotLoadedError{edge: "adoption_decisions"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentity) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentity.FieldID, authidentity.FieldUserID:
+ values[i] = new(sql.NullInt64)
+ case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer:
+ values[i] = new(sql.NullString)
+ case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentity fields.
+func (_m *AuthIdentity) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentity.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentity.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentity.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case authidentity.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentity.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentity.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case authidentity.FieldVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field verified_at", values[i])
+ } else if value.Valid {
+ _m.VerifiedAt = new(time.Time)
+ *_m.VerifiedAt = value.Time
+ }
+ case authidentity.FieldIssuer:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field issuer", values[i])
+ } else if value.Valid {
+ _m.Issuer = new(string)
+ *_m.Issuer = value.String
+ }
+ case authidentity.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentity) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryUser queries the "user" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryUser() *UserQuery {
+ return NewAuthIdentityClient(_m.config).QueryUser(_m)
+}
+
+// QueryChannels queries the "channels" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery {
+ return NewAuthIdentityClient(_m.config).QueryChannels(_m)
+}
+
+// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m)
+}
+
+// Update returns a builder for updating this AuthIdentity.
+// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne {
+ return NewAuthIdentityClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentity) Unwrap() *AuthIdentity {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentity is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentity) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentity(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.VerifiedAt; v != nil {
+ builder.WriteString("verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.Issuer; v != nil {
+ builder.WriteString("issuer=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentities is a parsable slice of AuthIdentity.
+type AuthIdentities []*AuthIdentity
diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go
new file mode 100644
index 00000000..c90be759
--- /dev/null
+++ b/backend/ent/authidentity/authidentity.go
@@ -0,0 +1,209 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentity type in the database.
+ Label = "auth_identity"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldVerifiedAt holds the string denoting the verified_at field in the database.
+ FieldVerifiedAt = "verified_at"
+ // FieldIssuer holds the string denoting the issuer field in the database.
+ FieldIssuer = "issuer"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // EdgeChannels holds the string denoting the channels edge name in mutations.
+ EdgeChannels = "channels"
+ // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations.
+ EdgeAdoptionDecisions = "adoption_decisions"
+ // Table holds the table name of the authidentity in the database.
+ Table = "auth_identities"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "auth_identities"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+ // ChannelsTable is the table that holds the channels relation/edge.
+ ChannelsTable = "auth_identity_channels"
+ // ChannelsInverseTable is the table name for the AuthIdentityChannel entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package.
+ ChannelsInverseTable = "auth_identity_channels"
+ // ChannelsColumn is the table column denoting the channels relation/edge.
+ ChannelsColumn = "identity_id"
+ // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge.
+ AdoptionDecisionsTable = "identity_adoption_decisions"
+ // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionsInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge.
+ AdoptionDecisionsColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentity fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldUserID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldVerifiedAt,
+ FieldIssuer,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentity queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByVerifiedAt orders the results by the verified_at field.
+func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc()
+}
+
+// ByIssuer orders the results by the issuer field.
+func ByIssuer(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIssuer, opts...).ToFunc()
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByChannelsCount orders the results by channels count.
+func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...)
+ }
+}
+
+// ByChannels orders the results by channels terms.
+func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByAdoptionDecisionsCount orders the results by adoption_decisions count.
+func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...)
+ }
+}
+
+// ByAdoptionDecisions orders the results by adoption_decisions terms.
+func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
+func newChannelsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(ChannelsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+}
+func newAdoptionDecisionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+}
diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go
new file mode 100644
index 00000000..3dbf3178
--- /dev/null
+++ b/backend/ent/authidentity/where.go
@@ -0,0 +1,600 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ.
+func VerifiedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ.
+func Issuer(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// VerifiedAtEQ applies the EQ predicate on the "verified_at" field.
+func VerifiedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field.
+func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIn applies the In predicate on the "verified_at" field.
+func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field.
+func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtGT applies the GT predicate on the "verified_at" field.
+func VerifiedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtGTE applies the GTE predicate on the "verified_at" field.
+func VerifiedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLT applies the LT predicate on the "verified_at" field.
+func VerifiedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLTE applies the LTE predicate on the "verified_at" field.
+func VerifiedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field.
+func VerifiedAtIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt))
+}
+
+// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field.
+func VerifiedAtNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt))
+}
+
+// IssuerEQ applies the EQ predicate on the "issuer" field.
+func IssuerEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// IssuerNEQ applies the NEQ predicate on the "issuer" field.
+func IssuerNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v))
+}
+
+// IssuerIn applies the In predicate on the "issuer" field.
+func IssuerIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...))
+}
+
+// IssuerNotIn applies the NotIn predicate on the "issuer" field.
+func IssuerNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...))
+}
+
+// IssuerGT applies the GT predicate on the "issuer" field.
+func IssuerGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v))
+}
+
+// IssuerGTE applies the GTE predicate on the "issuer" field.
+func IssuerGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v))
+}
+
+// IssuerLT applies the LT predicate on the "issuer" field.
+func IssuerLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v))
+}
+
+// IssuerLTE applies the LTE predicate on the "issuer" field.
+func IssuerLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v))
+}
+
+// IssuerContains applies the Contains predicate on the "issuer" field.
+func IssuerContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v))
+}
+
+// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field.
+func IssuerHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v))
+}
+
+// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field.
+func IssuerHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v))
+}
+
+// IssuerIsNil applies the IsNil predicate on the "issuer" field.
+func IssuerIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer))
+}
+
+// IssuerNotNil applies the NotNil predicate on the "issuer" field.
+func IssuerNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer))
+}
+
+// IssuerEqualFold applies the EqualFold predicate on the "issuer" field.
+func IssuerEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v))
+}
+
+// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field.
+func IssuerContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v))
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasChannels applies the HasEdge predicate on the "channels" edge.
+func HasChannels() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates).
+func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newChannelsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge.
+func HasAdoptionDecisions() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates).
+func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newAdoptionDecisionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go
new file mode 100644
index 00000000..e287705c
--- /dev/null
+++ b/backend/ent/authidentity_create.go
@@ -0,0 +1,1036 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityCreate is the builder for creating a AuthIdentity entity.
+type AuthIdentityCreate struct {
+ config
+ mutation *AuthIdentityMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetVerifiedAt(v)
+ return _c
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetIssuer sets the "issuer" field.
+func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate {
+ _c.mutation.SetIssuer(v)
+ return _c
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetIssuer(*v)
+ }
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddChannelIDs(ids...)
+ return _c
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddAdoptionDecisionIDs(ids...)
+ return _c
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentity in the database.
+func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentity.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentity.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentity.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)}
+ }
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentity{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ _node.VerifiedAt = &value
+ }
+ if value, ok := _c.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ _node.Issuer = &value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentity node.
+ AuthIdentityUpsertOne struct {
+ create *AuthIdentityCreate
+ }
+
+ // AuthIdentityUpsert is the "OnConflict" setter.
+ AuthIdentityUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUpdatedAt)
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUserID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderSubject)
+ return u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldVerifiedAt, v)
+ return u
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldIssuer, v)
+ return u
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldIssuer)
+ return u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldIssuer)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk.
+type AuthIdentityCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentity entities in the database.
+func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentity, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentity nodes.
+type AuthIdentityUpsertBulk struct {
+ create *AuthIdentityCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go
new file mode 100644
index 00000000..4f1f6f3c
--- /dev/null
+++ b/backend/ent/authidentity_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityDelete is the builder for deleting a AuthIdentity entity.
+type AuthIdentityDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity.
+type AuthIdentityDeleteOne struct {
+ _d *AuthIdentityDelete
+}
+
+// Where appends a list predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentity.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go
new file mode 100644
index 00000000..ff27ef3c
--- /dev/null
+++ b/backend/ent/authidentity_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityQuery is the builder for querying AuthIdentity entities.
+type AuthIdentityQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentity.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentity
+ withUser *UserQuery
+ withChannels *AuthIdentityChannelQuery
+ withAdoptionDecisions *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityQuery builder.
+func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *AuthIdentityQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryChannels chains the current query on the "channels" edge.
+func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge.
+func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentity entity from the query.
+// Returns a *NotFoundError when no AuthIdentity was found.
+func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentity.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentity ID from the query.
+// Returns a *NotFoundError when no AuthIdentity ID was found.
+func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentity.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentity entity is found.
+// Returns a *NotFoundError when no AuthIdentity entities are found.
+func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentity.Label}
+ default:
+ return nil, &NotSingularError{authidentity.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentity ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentity ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentity.Label}
+ default:
+ err = &NotSingularError{authidentity.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentities.
+func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]()
+ return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentity IDs.
+func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentity.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentity{}, _q.predicates...),
+ withUser: _q.withUser.Clone(),
+ withChannels: _q.withChannels.Clone(),
+ withAdoptionDecisions: _q.withAdoptionDecisions.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// WithChannels tells the query-builder to eager-load the nodes that are connected to
+// the "channels" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withChannels = query
+ return _q
+}
+
+// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecisions = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// GroupBy(authidentity.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentity.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// Select(authidentity.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q}
+ sbuild.label = authidentity.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentitySelect configured with the given aggregations.
+func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentity.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) {
+ var (
+ nodes = []*AuthIdentity{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withUser != nil,
+ _q.withChannels != nil,
+ _q.withAdoptionDecisions != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentity).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentity{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withChannels; query != nil {
+ if err := _q.loadChannels(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} },
+ func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecisions; query != nil {
+ if err := _q.loadAdoptionDecisions(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} },
+ func(n *AuthIdentity, e *IdentityAdoptionDecision) {
+ n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentity)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID)
+ }
+ query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for i := range fields {
+ if fields[i] != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(authidentity.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentity.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentity.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities.
+type AuthIdentityGroupBy struct {
+ selector
+ build *AuthIdentityQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities.
+type AuthIdentitySelect struct {
+ *AuthIdentityQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go
new file mode 100644
index 00000000..c457470b
--- /dev/null
+++ b/backend/ent/authidentity_update.go
@@ -0,0 +1,923 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityUpdate is the builder for updating AuthIdentity entities.
+type AuthIdentityUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity.
+type AuthIdentityUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentity entity.
+func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for _, f := range fields {
+ if !authidentity.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentity{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go
new file mode 100644
index 00000000..1ff3e5d1
--- /dev/null
+++ b/backend/ent/authidentitychannel.go
@@ -0,0 +1,228 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema.
+type AuthIdentityChannel struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID int64 `json:"identity_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // Channel holds the value of the "channel" field.
+ Channel string `json:"channel,omitempty"`
+ // ChannelAppID holds the value of the "channel_app_id" field.
+ ChannelAppID string `json:"channel_app_id,omitempty"`
+ // ChannelSubject holds the value of the "channel_subject" field.
+ ChannelSubject string `json:"channel_subject,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set.
+ Edges AuthIdentityChannelEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityChannelEdges struct {
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject:
+ values[i] = new(sql.NullString)
+ case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentityChannel fields.
+func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentitychannel.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentitychannel.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentitychannel.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = value.Int64
+ }
+ case authidentitychannel.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentitychannel.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentitychannel.FieldChannel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel", values[i])
+ } else if value.Valid {
+ _m.Channel = value.String
+ }
+ case authidentitychannel.FieldChannelAppID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_app_id", values[i])
+ } else if value.Valid {
+ _m.ChannelAppID = value.String
+ }
+ case authidentitychannel.FieldChannelSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_subject", values[i])
+ } else if value.Valid {
+ _m.ChannelSubject = value.String
+ }
+ case authidentitychannel.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity.
+func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery {
+ return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this AuthIdentityChannel.
+// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne {
+ return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentityChannel is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentityChannel) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentityChannel(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IdentityID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("channel=")
+ builder.WriteString(_m.Channel)
+ builder.WriteString(", ")
+ builder.WriteString("channel_app_id=")
+ builder.WriteString(_m.ChannelAppID)
+ builder.WriteString(", ")
+ builder.WriteString("channel_subject=")
+ builder.WriteString(_m.ChannelSubject)
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentityChannels is a parsable slice of AuthIdentityChannel.
+type AuthIdentityChannels []*AuthIdentityChannel
diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go
new file mode 100644
index 00000000..7dcc98bb
--- /dev/null
+++ b/backend/ent/authidentitychannel/authidentitychannel.go
@@ -0,0 +1,153 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentitychannel type in the database.
+ Label = "auth_identity_channel"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldChannel holds the string denoting the channel field in the database.
+ FieldChannel = "channel"
+ // FieldChannelAppID holds the string denoting the channel_app_id field in the database.
+ FieldChannelAppID = "channel_app_id"
+ // FieldChannelSubject holds the string denoting the channel_subject field in the database.
+ FieldChannelSubject = "channel_subject"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the authidentitychannel in the database.
+ Table = "auth_identity_channels"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "auth_identity_channels"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentitychannel fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldIdentityID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldChannel,
+ FieldChannelAppID,
+ FieldChannelSubject,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ ChannelValidator func(string) error
+ // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ ChannelAppIDValidator func(string) error
+ // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ ChannelSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentityChannel queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByChannel orders the results by the channel field.
+func ByChannel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannel, opts...).ToFunc()
+}
+
+// ByChannelAppID orders the results by the channel_app_id field.
+func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelAppID, opts...).ToFunc()
+}
+
+// ByChannelSubject orders the results by the channel_subject field.
+func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelSubject, opts...).ToFunc()
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go
new file mode 100644
index 00000000..827dc384
--- /dev/null
+++ b/backend/ent/authidentitychannel/where.go
@@ -0,0 +1,559 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ.
+func Channel(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ.
+func ChannelAppID(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ.
+func ChannelSubject(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ChannelEQ applies the EQ predicate on the "channel" field.
+func ChannelEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelNEQ applies the NEQ predicate on the "channel" field.
+func ChannelNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v))
+}
+
+// ChannelIn applies the In predicate on the "channel" field.
+func ChannelIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...))
+}
+
+// ChannelNotIn applies the NotIn predicate on the "channel" field.
+func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...))
+}
+
+// ChannelGT applies the GT predicate on the "channel" field.
+func ChannelGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v))
+}
+
+// ChannelGTE applies the GTE predicate on the "channel" field.
+func ChannelGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v))
+}
+
+// ChannelLT applies the LT predicate on the "channel" field.
+func ChannelLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v))
+}
+
+// ChannelLTE applies the LTE predicate on the "channel" field.
+func ChannelLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v))
+}
+
+// ChannelContains applies the Contains predicate on the "channel" field.
+func ChannelContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v))
+}
+
+// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field.
+func ChannelHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v))
+}
+
+// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field.
+func ChannelHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v))
+}
+
+// ChannelEqualFold applies the EqualFold predicate on the "channel" field.
+func ChannelEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v))
+}
+
+// ChannelContainsFold applies the ContainsFold predicate on the "channel" field.
+func ChannelContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v))
+}
+
+// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field.
+func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field.
+func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDIn applies the In predicate on the "channel_app_id" field.
+func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field.
+func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field.
+func ChannelAppIDGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field.
+func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field.
+func ChannelAppIDLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field.
+func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field.
+func ChannelAppIDContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field.
+func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field.
+func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field.
+func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field.
+func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v))
+}
+
+// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field.
+func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field.
+func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectIn applies the In predicate on the "channel_subject" field.
+func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field.
+func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectGT applies the GT predicate on the "channel_subject" field.
+func ChannelSubjectGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field.
+func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLT applies the LT predicate on the "channel_subject" field.
+func ChannelSubjectLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field.
+func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field.
+func ChannelSubjectContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field.
+func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field.
+func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field.
+func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field.
+func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v))
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go
new file mode 100644
index 00000000..4ce28479
--- /dev/null
+++ b/backend/ent/authidentitychannel_create.go
@@ -0,0 +1,932 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity.
+type AuthIdentityChannelCreate struct {
+ config
+ mutation *AuthIdentityChannelMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetChannel sets the "channel" field.
+func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannel(v)
+ return _c
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelAppID(v)
+ return _c
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelSubject(v)
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentityChannel in the database.
+func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityChannelCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentitychannel.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentitychannel.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityChannelCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)}
+ }
+ if _, ok := _c.mutation.IdentityID(); !ok {
+ return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Channel(); !ok {
+ return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)}
+ }
+ if v, ok := _c.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelAppID(); !ok {
+ return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)}
+ }
+ if v, ok := _c.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelSubject(); !ok {
+ return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)}
+ }
+ if v, ok := _c.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)}
+ }
+ if len(_c.mutation.IdentityIDs()) == 0 {
+ return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentityChannel{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ _node.Channel = value
+ }
+ if value, ok := _c.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ _node.ChannelAppID = value
+ }
+ if value, ok := _c.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ _node.ChannelSubject = value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentityChannel node.
+ AuthIdentityChannelUpsertOne struct {
+ create *AuthIdentityChannelCreate
+ }
+
+ // AuthIdentityChannelUpsert is the "OnConflict" setter.
+ AuthIdentityChannelUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldUpdatedAt)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldIdentityID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderKey)
+ return u
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannel, v)
+ return u
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannel)
+ return u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelAppID, v)
+ return u
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelAppID)
+ return u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelSubject, v)
+ return u
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelSubject)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk.
+type AuthIdentityChannelCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityChannelCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentityChannel entities in the database.
+func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentityChannel, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityChannelMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentityChannel nodes.
+type AuthIdentityChannelUpsertBulk struct {
+ create *AuthIdentityChannelCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go
new file mode 100644
index 00000000..1a4acac5
--- /dev/null
+++ b/backend/ent/authidentitychannel_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity.
+type AuthIdentityChannelDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity.
+type AuthIdentityChannelDeleteOne struct {
+ _d *AuthIdentityChannelDelete
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go
new file mode 100644
index 00000000..7a202b7f
--- /dev/null
+++ b/backend/ent/authidentitychannel_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities.
+type AuthIdentityChannelQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentitychannel.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentityChannel
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityChannelQuery builder.
+func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentityChannel entity from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel was found.
+func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentitychannel.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentityChannel ID from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel ID was found.
+func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentitychannel.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found.
+// Returns a *NotFoundError when no AuthIdentityChannel entities are found.
+func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil, &NotSingularError{authidentitychannel.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentitychannel.Label}
+ default:
+ err = &NotSingularError{authidentitychannel.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentityChannels.
+func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]()
+ return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentityChannel IDs.
+func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityChannelQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentitychannel.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// GroupBy(authidentitychannel.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityChannelGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentitychannel.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// Select(authidentitychannel.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q}
+ sbuild.label = authidentitychannel.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations.
+func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) {
+ var (
+ nodes = []*AuthIdentityChannel{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentityChannel).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentityChannel{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentityChannel)
+ for i := range nodes {
+ fk := nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for i := range fields {
+ if fields[i] != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentitychannel.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentitychannel.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities.
+type AuthIdentityChannelGroupBy struct {
+ selector
+ build *AuthIdentityChannelQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities.
+type AuthIdentityChannelSelect struct {
+ *AuthIdentityChannelQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go
new file mode 100644
index 00000000..b550c454
--- /dev/null
+++ b/backend/ent/authidentitychannel_update.go
@@ -0,0 +1,581 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities.
+type AuthIdentityChannelUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity.
+type AuthIdentityChannelUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentityChannel entity.
+func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for _, f := range fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentityChannel{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitor.go b/backend/ent/channelmonitor.go
new file mode 100644
index 00000000..dbb73362
--- /dev/null
+++ b/backend/ent/channelmonitor.go
@@ -0,0 +1,359 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitor is the model entity for the ChannelMonitor schema.
+type ChannelMonitor struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Provider holds the value of the "provider" field.
+ Provider channelmonitor.Provider `json:"provider,omitempty"`
+ // Provider base origin, e.g. https://api.openai.com
+ Endpoint string `json:"endpoint,omitempty"`
+ // AES-256-GCM encrypted API key
+ APIKeyEncrypted string `json:"-"`
+ // PrimaryModel holds the value of the "primary_model" field.
+ PrimaryModel string `json:"primary_model,omitempty"`
+ // Additional model names to test alongside primary_model
+ ExtraModels []string `json:"extra_models,omitempty"`
+ // GroupName holds the value of the "group_name" field.
+ GroupName string `json:"group_name,omitempty"`
+ // Enabled holds the value of the "enabled" field.
+ Enabled bool `json:"enabled,omitempty"`
+ // IntervalSeconds holds the value of the "interval_seconds" field.
+ IntervalSeconds int `json:"interval_seconds,omitempty"`
+ // LastCheckedAt holds the value of the "last_checked_at" field.
+ LastCheckedAt *time.Time `json:"last_checked_at,omitempty"`
+ // CreatedBy holds the value of the "created_by" field.
+ CreatedBy int64 `json:"created_by,omitempty"`
+ // TemplateID holds the value of the "template_id" field.
+ TemplateID *int64 `json:"template_id,omitempty"`
+ // ExtraHeaders holds the value of the "extra_headers" field.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // BodyOverrideMode holds the value of the "body_override_mode" field.
+ BodyOverrideMode string `json:"body_override_mode,omitempty"`
+ // BodyOverride holds the value of the "body_override" field.
+ BodyOverride map[string]interface{} `json:"body_override,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorQuery when eager-loading is set.
+ Edges ChannelMonitorEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorEdges struct {
+ // History holds the value of the history edge.
+ History []*ChannelMonitorHistory `json:"history,omitempty"`
+ // DailyRollups holds the value of the daily_rollups edge.
+ DailyRollups []*ChannelMonitorDailyRollup `json:"daily_rollups,omitempty"`
+ // RequestTemplate holds the value of the request_template edge.
+ RequestTemplate *ChannelMonitorRequestTemplate `json:"request_template,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// HistoryOrErr returns the History value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorEdges) HistoryOrErr() ([]*ChannelMonitorHistory, error) {
+ if e.loadedTypes[0] {
+ return e.History, nil
+ }
+ return nil, &NotLoadedError{edge: "history"}
+}
+
+// DailyRollupsOrErr returns the DailyRollups value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorEdges) DailyRollupsOrErr() ([]*ChannelMonitorDailyRollup, error) {
+ if e.loadedTypes[1] {
+ return e.DailyRollups, nil
+ }
+ return nil, &NotLoadedError{edge: "daily_rollups"}
+}
+
+// RequestTemplateOrErr returns the RequestTemplate value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorEdges) RequestTemplateOrErr() (*ChannelMonitorRequestTemplate, error) {
+ if e.RequestTemplate != nil {
+ return e.RequestTemplate, nil
+ } else if e.loadedTypes[2] {
+ return nil, &NotFoundError{label: channelmonitorrequesttemplate.Label}
+ }
+ return nil, &NotLoadedError{edge: "request_template"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitor.FieldExtraModels, channelmonitor.FieldExtraHeaders, channelmonitor.FieldBodyOverride:
+ values[i] = new([]byte)
+ case channelmonitor.FieldEnabled:
+ values[i] = new(sql.NullBool)
+ case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID:
+ values[i] = new(sql.NullInt64)
+ case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
+ values[i] = new(sql.NullString)
+ case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitor fields.
+func (_m *ChannelMonitor) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitor.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitor.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case channelmonitor.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case channelmonitor.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case channelmonitor.FieldProvider:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider", values[i])
+ } else if value.Valid {
+ _m.Provider = channelmonitor.Provider(value.String)
+ }
+ case channelmonitor.FieldEndpoint:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field endpoint", values[i])
+ } else if value.Valid {
+ _m.Endpoint = value.String
+ }
+ case channelmonitor.FieldAPIKeyEncrypted:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field api_key_encrypted", values[i])
+ } else if value.Valid {
+ _m.APIKeyEncrypted = value.String
+ }
+ case channelmonitor.FieldPrimaryModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field primary_model", values[i])
+ } else if value.Valid {
+ _m.PrimaryModel = value.String
+ }
+ case channelmonitor.FieldExtraModels:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_models", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraModels); err != nil {
+ return fmt.Errorf("unmarshal field extra_models: %w", err)
+ }
+ }
+ case channelmonitor.FieldGroupName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field group_name", values[i])
+ } else if value.Valid {
+ _m.GroupName = value.String
+ }
+ case channelmonitor.FieldEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field enabled", values[i])
+ } else if value.Valid {
+ _m.Enabled = value.Bool
+ }
+ case channelmonitor.FieldIntervalSeconds:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field interval_seconds", values[i])
+ } else if value.Valid {
+ _m.IntervalSeconds = int(value.Int64)
+ }
+ case channelmonitor.FieldLastCheckedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_checked_at", values[i])
+ } else if value.Valid {
+ _m.LastCheckedAt = new(time.Time)
+ *_m.LastCheckedAt = value.Time
+ }
+ case channelmonitor.FieldCreatedBy:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field created_by", values[i])
+ } else if value.Valid {
+ _m.CreatedBy = value.Int64
+ }
+ case channelmonitor.FieldTemplateID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field template_id", values[i])
+ } else if value.Valid {
+ _m.TemplateID = new(int64)
+ *_m.TemplateID = value.Int64
+ }
+ case channelmonitor.FieldExtraHeaders:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_headers", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil {
+ return fmt.Errorf("unmarshal field extra_headers: %w", err)
+ }
+ }
+ case channelmonitor.FieldBodyOverrideMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override_mode", values[i])
+ } else if value.Valid {
+ _m.BodyOverrideMode = value.String
+ }
+ case channelmonitor.FieldBodyOverride:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil {
+ return fmt.Errorf("unmarshal field body_override: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitor.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitor) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryHistory queries the "history" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryHistory() *ChannelMonitorHistoryQuery {
+ return NewChannelMonitorClient(_m.config).QueryHistory(_m)
+}
+
+// QueryDailyRollups queries the "daily_rollups" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
+ return NewChannelMonitorClient(_m.config).QueryDailyRollups(_m)
+}
+
+// QueryRequestTemplate queries the "request_template" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery {
+ return NewChannelMonitorClient(_m.config).QueryRequestTemplate(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitor.
+// Note that you need to call ChannelMonitor.Unwrap() before calling this method if this ChannelMonitor
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitor) Update() *ChannelMonitorUpdateOne {
+ return NewChannelMonitorClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitor entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitor) Unwrap() *ChannelMonitor {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitor is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitor) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitor(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("provider=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Provider))
+ builder.WriteString(", ")
+ builder.WriteString("endpoint=")
+ builder.WriteString(_m.Endpoint)
+ builder.WriteString(", ")
+ builder.WriteString("api_key_encrypted=")
+ builder.WriteString(", ")
+ builder.WriteString("primary_model=")
+ builder.WriteString(_m.PrimaryModel)
+ builder.WriteString(", ")
+ builder.WriteString("extra_models=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraModels))
+ builder.WriteString(", ")
+ builder.WriteString("group_name=")
+ builder.WriteString(_m.GroupName)
+ builder.WriteString(", ")
+ builder.WriteString("enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
+ builder.WriteString(", ")
+ builder.WriteString("interval_seconds=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IntervalSeconds))
+ builder.WriteString(", ")
+ if v := _m.LastCheckedAt; v != nil {
+ builder.WriteString("last_checked_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("created_by=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
+ builder.WriteString(", ")
+ if v := _m.TemplateID; v != nil {
+ builder.WriteString("template_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("extra_headers=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders))
+ builder.WriteString(", ")
+ builder.WriteString("body_override_mode=")
+ builder.WriteString(_m.BodyOverrideMode)
+ builder.WriteString(", ")
+ builder.WriteString("body_override=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitors is a parsable slice of ChannelMonitor.
+type ChannelMonitors []*ChannelMonitor
diff --git a/backend/ent/channelmonitor/channelmonitor.go b/backend/ent/channelmonitor/channelmonitor.go
new file mode 100644
index 00000000..e5a6bfe7
--- /dev/null
+++ b/backend/ent/channelmonitor/channelmonitor.go
@@ -0,0 +1,304 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitor
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitor type in the database.
+ Label = "channel_monitor"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldProvider holds the string denoting the provider field in the database.
+ FieldProvider = "provider"
+ // FieldEndpoint holds the string denoting the endpoint field in the database.
+ FieldEndpoint = "endpoint"
+ // FieldAPIKeyEncrypted holds the string denoting the api_key_encrypted field in the database.
+ FieldAPIKeyEncrypted = "api_key_encrypted"
+ // FieldPrimaryModel holds the string denoting the primary_model field in the database.
+ FieldPrimaryModel = "primary_model"
+ // FieldExtraModels holds the string denoting the extra_models field in the database.
+ FieldExtraModels = "extra_models"
+ // FieldGroupName holds the string denoting the group_name field in the database.
+ FieldGroupName = "group_name"
+ // FieldEnabled holds the string denoting the enabled field in the database.
+ FieldEnabled = "enabled"
+ // FieldIntervalSeconds holds the string denoting the interval_seconds field in the database.
+ FieldIntervalSeconds = "interval_seconds"
+ // FieldLastCheckedAt holds the string denoting the last_checked_at field in the database.
+ FieldLastCheckedAt = "last_checked_at"
+ // FieldCreatedBy holds the string denoting the created_by field in the database.
+ FieldCreatedBy = "created_by"
+ // FieldTemplateID holds the string denoting the template_id field in the database.
+ FieldTemplateID = "template_id"
+ // FieldExtraHeaders holds the string denoting the extra_headers field in the database.
+ FieldExtraHeaders = "extra_headers"
+ // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database.
+ FieldBodyOverrideMode = "body_override_mode"
+ // FieldBodyOverride holds the string denoting the body_override field in the database.
+ FieldBodyOverride = "body_override"
+ // EdgeHistory holds the string denoting the history edge name in mutations.
+ EdgeHistory = "history"
+ // EdgeDailyRollups holds the string denoting the daily_rollups edge name in mutations.
+ EdgeDailyRollups = "daily_rollups"
+ // EdgeRequestTemplate holds the string denoting the request_template edge name in mutations.
+ EdgeRequestTemplate = "request_template"
+ // Table holds the table name of the channelmonitor in the database.
+ Table = "channel_monitors"
+ // HistoryTable is the table that holds the history relation/edge.
+ HistoryTable = "channel_monitor_histories"
+ // HistoryInverseTable is the table name for the ChannelMonitorHistory entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitorhistory" package.
+ HistoryInverseTable = "channel_monitor_histories"
+ // HistoryColumn is the table column denoting the history relation/edge.
+ HistoryColumn = "monitor_id"
+ // DailyRollupsTable is the table that holds the daily_rollups relation/edge.
+ DailyRollupsTable = "channel_monitor_daily_rollups"
+ // DailyRollupsInverseTable is the table name for the ChannelMonitorDailyRollup entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitordailyrollup" package.
+ DailyRollupsInverseTable = "channel_monitor_daily_rollups"
+ // DailyRollupsColumn is the table column denoting the daily_rollups relation/edge.
+ DailyRollupsColumn = "monitor_id"
+ // RequestTemplateTable is the table that holds the request_template relation/edge.
+ RequestTemplateTable = "channel_monitors"
+ // RequestTemplateInverseTable is the table name for the ChannelMonitorRequestTemplate entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitorrequesttemplate" package.
+ RequestTemplateInverseTable = "channel_monitor_request_templates"
+ // RequestTemplateColumn is the table column denoting the request_template relation/edge.
+ RequestTemplateColumn = "template_id"
+)
+
+// Columns holds all SQL columns for channelmonitor fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldName,
+ FieldProvider,
+ FieldEndpoint,
+ FieldAPIKeyEncrypted,
+ FieldPrimaryModel,
+ FieldExtraModels,
+ FieldGroupName,
+ FieldEnabled,
+ FieldIntervalSeconds,
+ FieldLastCheckedAt,
+ FieldCreatedBy,
+ FieldTemplateID,
+ FieldExtraHeaders,
+ FieldBodyOverrideMode,
+ FieldBodyOverride,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
+ EndpointValidator func(string) error
+ // APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
+ APIKeyEncryptedValidator func(string) error
+ // PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
+ PrimaryModelValidator func(string) error
+ // DefaultExtraModels holds the default value on creation for the "extra_models" field.
+ DefaultExtraModels []string
+ // DefaultGroupName holds the default value on creation for the "group_name" field.
+ DefaultGroupName string
+ // GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
+ GroupNameValidator func(string) error
+ // DefaultEnabled holds the default value on creation for the "enabled" field.
+ DefaultEnabled bool
+ // IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
+ IntervalSecondsValidator func(int) error
+ // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field.
+ DefaultExtraHeaders map[string]string
+ // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field.
+ DefaultBodyOverrideMode string
+ // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ BodyOverrideModeValidator func(string) error
+)
+
+// Provider defines the type for the "provider" enum field.
+type Provider string
+
+// Provider values.
+const (
+ ProviderOpenai Provider = "openai"
+ ProviderAnthropic Provider = "anthropic"
+ ProviderGemini Provider = "gemini"
+)
+
+func (pr Provider) String() string {
+ return string(pr)
+}
+
+// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save.
+func ProviderValidator(pr Provider) error {
+ switch pr {
+ case ProviderOpenai, ProviderAnthropic, ProviderGemini:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitor: invalid enum value for provider field: %q", pr)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitor queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByProvider orders the results by the provider field.
+func ByProvider(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProvider, opts...).ToFunc()
+}
+
+// ByEndpoint orders the results by the endpoint field.
+func ByEndpoint(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEndpoint, opts...).ToFunc()
+}
+
+// ByAPIKeyEncrypted orders the results by the api_key_encrypted field.
+func ByAPIKeyEncrypted(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAPIKeyEncrypted, opts...).ToFunc()
+}
+
+// ByPrimaryModel orders the results by the primary_model field.
+func ByPrimaryModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPrimaryModel, opts...).ToFunc()
+}
+
+// ByGroupName orders the results by the group_name field.
+func ByGroupName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldGroupName, opts...).ToFunc()
+}
+
+// ByEnabled orders the results by the enabled field.
+func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEnabled, opts...).ToFunc()
+}
+
+// ByIntervalSeconds orders the results by the interval_seconds field.
+func ByIntervalSeconds(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntervalSeconds, opts...).ToFunc()
+}
+
+// ByLastCheckedAt orders the results by the last_checked_at field.
+func ByLastCheckedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastCheckedAt, opts...).ToFunc()
+}
+
+// ByCreatedBy orders the results by the created_by field.
+func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
+}
+
+// ByTemplateID orders the results by the template_id field.
+func ByTemplateID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTemplateID, opts...).ToFunc()
+}
+
+// ByBodyOverrideMode orders the results by the body_override_mode field.
+func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc()
+}
+
+// ByHistoryCount orders the results by history count.
+func ByHistoryCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newHistoryStep(), opts...)
+ }
+}
+
+// ByHistory orders the results by history terms.
+func ByHistory(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newHistoryStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByDailyRollupsCount orders the results by daily_rollups count.
+func ByDailyRollupsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newDailyRollupsStep(), opts...)
+ }
+}
+
+// ByDailyRollups orders the results by daily_rollups terms.
+func ByDailyRollups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newDailyRollupsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByRequestTemplateField orders the results by request_template field.
+func ByRequestTemplateField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newRequestTemplateStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newHistoryStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(HistoryInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn),
+ )
+}
+func newDailyRollupsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(DailyRollupsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
+ )
+}
+func newRequestTemplateStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(RequestTemplateInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn),
+ )
+}
diff --git a/backend/ent/channelmonitor/where.go b/backend/ent/channelmonitor/where.go
new file mode 100644
index 00000000..755d83a3
--- /dev/null
+++ b/backend/ent/channelmonitor/where.go
@@ -0,0 +1,885 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitor
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
+}
+
+// Endpoint applies equality check predicate on the "endpoint" field. It's identical to EndpointEQ.
+func Endpoint(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
+}
+
+// APIKeyEncrypted applies equality check predicate on the "api_key_encrypted" field. It's identical to APIKeyEncryptedEQ.
+func APIKeyEncrypted(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIKeyEncrypted, v))
+}
+
+// PrimaryModel applies equality check predicate on the "primary_model" field. It's identical to PrimaryModelEQ.
+func PrimaryModel(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldPrimaryModel, v))
+}
+
+// GroupName applies equality check predicate on the "group_name" field. It's identical to GroupNameEQ.
+func GroupName(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldGroupName, v))
+}
+
+// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
+func Enabled(v bool) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEnabled, v))
+}
+
+// IntervalSeconds applies equality check predicate on the "interval_seconds" field. It's identical to IntervalSecondsEQ.
+func IntervalSeconds(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldIntervalSeconds, v))
+}
+
+// LastCheckedAt applies equality check predicate on the "last_checked_at" field. It's identical to LastCheckedAtEQ.
+func LastCheckedAt(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldLastCheckedAt, v))
+}
+
+// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
+func CreatedBy(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// TemplateID applies equality check predicate on the "template_id" field. It's identical to TemplateIDEQ.
+func TemplateID(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v))
+}
+
+// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ.
+func BodyOverrideMode(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldName, v))
+}
+
+// ProviderEQ applies the EQ predicate on the "provider" field.
+func ProviderEQ(v Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldProvider, v))
+}
+
+// ProviderNEQ applies the NEQ predicate on the "provider" field.
+func ProviderNEQ(v Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldProvider, v))
+}
+
+// ProviderIn applies the In predicate on the "provider" field.
+func ProviderIn(vs ...Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldProvider, vs...))
+}
+
+// ProviderNotIn applies the NotIn predicate on the "provider" field.
+func ProviderNotIn(vs ...Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldProvider, vs...))
+}
+
+// EndpointEQ applies the EQ predicate on the "endpoint" field.
+func EndpointEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
+}
+
+// EndpointNEQ applies the NEQ predicate on the "endpoint" field.
+func EndpointNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldEndpoint, v))
+}
+
+// EndpointIn applies the In predicate on the "endpoint" field.
+func EndpointIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldEndpoint, vs...))
+}
+
+// EndpointNotIn applies the NotIn predicate on the "endpoint" field.
+func EndpointNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldEndpoint, vs...))
+}
+
+// EndpointGT applies the GT predicate on the "endpoint" field.
+func EndpointGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldEndpoint, v))
+}
+
+// EndpointGTE applies the GTE predicate on the "endpoint" field.
+func EndpointGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldEndpoint, v))
+}
+
+// EndpointLT applies the LT predicate on the "endpoint" field.
+func EndpointLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldEndpoint, v))
+}
+
+// EndpointLTE applies the LTE predicate on the "endpoint" field.
+func EndpointLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldEndpoint, v))
+}
+
+// EndpointContains applies the Contains predicate on the "endpoint" field.
+func EndpointContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldEndpoint, v))
+}
+
+// EndpointHasPrefix applies the HasPrefix predicate on the "endpoint" field.
+func EndpointHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldEndpoint, v))
+}
+
+// EndpointHasSuffix applies the HasSuffix predicate on the "endpoint" field.
+func EndpointHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldEndpoint, v))
+}
+
+// EndpointEqualFold applies the EqualFold predicate on the "endpoint" field.
+func EndpointEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldEndpoint, v))
+}
+
+// EndpointContainsFold applies the ContainsFold predicate on the "endpoint" field.
+func EndpointContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldEndpoint, v))
+}
+
+// APIKeyEncryptedEQ applies the EQ predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedNEQ applies the NEQ predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedIn applies the In predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldAPIKeyEncrypted, vs...))
+}
+
+// APIKeyEncryptedNotIn applies the NotIn predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldAPIKeyEncrypted, vs...))
+}
+
+// APIKeyEncryptedGT applies the GT predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedGTE applies the GTE predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedLT applies the LT predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedLTE applies the LTE predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedContains applies the Contains predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedHasPrefix applies the HasPrefix predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedHasSuffix applies the HasSuffix predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedEqualFold applies the EqualFold predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedContainsFold applies the ContainsFold predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldAPIKeyEncrypted, v))
+}
+
+// PrimaryModelEQ applies the EQ predicate on the "primary_model" field.
+func PrimaryModelEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldPrimaryModel, v))
+}
+
+// PrimaryModelNEQ applies the NEQ predicate on the "primary_model" field.
+func PrimaryModelNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldPrimaryModel, v))
+}
+
+// PrimaryModelIn applies the In predicate on the "primary_model" field.
+func PrimaryModelIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldPrimaryModel, vs...))
+}
+
+// PrimaryModelNotIn applies the NotIn predicate on the "primary_model" field.
+func PrimaryModelNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldPrimaryModel, vs...))
+}
+
+// PrimaryModelGT applies the GT predicate on the "primary_model" field.
+func PrimaryModelGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldPrimaryModel, v))
+}
+
+// PrimaryModelGTE applies the GTE predicate on the "primary_model" field.
+func PrimaryModelGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldPrimaryModel, v))
+}
+
+// PrimaryModelLT applies the LT predicate on the "primary_model" field.
+func PrimaryModelLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldPrimaryModel, v))
+}
+
+// PrimaryModelLTE applies the LTE predicate on the "primary_model" field.
+func PrimaryModelLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldPrimaryModel, v))
+}
+
+// PrimaryModelContains applies the Contains predicate on the "primary_model" field.
+func PrimaryModelContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldPrimaryModel, v))
+}
+
+// PrimaryModelHasPrefix applies the HasPrefix predicate on the "primary_model" field.
+func PrimaryModelHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldPrimaryModel, v))
+}
+
+// PrimaryModelHasSuffix applies the HasSuffix predicate on the "primary_model" field.
+func PrimaryModelHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldPrimaryModel, v))
+}
+
+// PrimaryModelEqualFold applies the EqualFold predicate on the "primary_model" field.
+func PrimaryModelEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldPrimaryModel, v))
+}
+
+// PrimaryModelContainsFold applies the ContainsFold predicate on the "primary_model" field.
+func PrimaryModelContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldPrimaryModel, v))
+}
+
+// GroupNameEQ applies the EQ predicate on the "group_name" field.
+func GroupNameEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldGroupName, v))
+}
+
+// GroupNameNEQ applies the NEQ predicate on the "group_name" field.
+func GroupNameNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldGroupName, v))
+}
+
+// GroupNameIn applies the In predicate on the "group_name" field.
+func GroupNameIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldGroupName, vs...))
+}
+
+// GroupNameNotIn applies the NotIn predicate on the "group_name" field.
+func GroupNameNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldGroupName, vs...))
+}
+
+// GroupNameGT applies the GT predicate on the "group_name" field.
+func GroupNameGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldGroupName, v))
+}
+
+// GroupNameGTE applies the GTE predicate on the "group_name" field.
+func GroupNameGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldGroupName, v))
+}
+
+// GroupNameLT applies the LT predicate on the "group_name" field.
+func GroupNameLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldGroupName, v))
+}
+
+// GroupNameLTE applies the LTE predicate on the "group_name" field.
+func GroupNameLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldGroupName, v))
+}
+
+// GroupNameContains applies the Contains predicate on the "group_name" field.
+func GroupNameContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldGroupName, v))
+}
+
+// GroupNameHasPrefix applies the HasPrefix predicate on the "group_name" field.
+func GroupNameHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldGroupName, v))
+}
+
+// GroupNameHasSuffix applies the HasSuffix predicate on the "group_name" field.
+func GroupNameHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldGroupName, v))
+}
+
+// GroupNameIsNil applies the IsNil predicate on the "group_name" field.
+func GroupNameIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldGroupName))
+}
+
+// GroupNameNotNil applies the NotNil predicate on the "group_name" field.
+func GroupNameNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldGroupName))
+}
+
+// GroupNameEqualFold applies the EqualFold predicate on the "group_name" field.
+func GroupNameEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldGroupName, v))
+}
+
+// GroupNameContainsFold applies the ContainsFold predicate on the "group_name" field.
+func GroupNameContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldGroupName, v))
+}
+
+// EnabledEQ applies the EQ predicate on the "enabled" field.
+func EnabledEQ(v bool) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEnabled, v))
+}
+
+// EnabledNEQ applies the NEQ predicate on the "enabled" field.
+func EnabledNEQ(v bool) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldEnabled, v))
+}
+
+// IntervalSecondsEQ applies the EQ predicate on the "interval_seconds" field.
+func IntervalSecondsEQ(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsNEQ applies the NEQ predicate on the "interval_seconds" field.
+func IntervalSecondsNEQ(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsIn applies the In predicate on the "interval_seconds" field.
+func IntervalSecondsIn(vs ...int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldIntervalSeconds, vs...))
+}
+
+// IntervalSecondsNotIn applies the NotIn predicate on the "interval_seconds" field.
+func IntervalSecondsNotIn(vs ...int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldIntervalSeconds, vs...))
+}
+
+// IntervalSecondsGT applies the GT predicate on the "interval_seconds" field.
+func IntervalSecondsGT(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsGTE applies the GTE predicate on the "interval_seconds" field.
+func IntervalSecondsGTE(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsLT applies the LT predicate on the "interval_seconds" field.
+func IntervalSecondsLT(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsLTE applies the LTE predicate on the "interval_seconds" field.
+func IntervalSecondsLTE(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldIntervalSeconds, v))
+}
+
+// LastCheckedAtEQ applies the EQ predicate on the "last_checked_at" field.
+func LastCheckedAtEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtNEQ applies the NEQ predicate on the "last_checked_at" field.
+func LastCheckedAtNEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtIn applies the In predicate on the "last_checked_at" field.
+func LastCheckedAtIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldLastCheckedAt, vs...))
+}
+
+// LastCheckedAtNotIn applies the NotIn predicate on the "last_checked_at" field.
+func LastCheckedAtNotIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldLastCheckedAt, vs...))
+}
+
+// LastCheckedAtGT applies the GT predicate on the "last_checked_at" field.
+func LastCheckedAtGT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtGTE applies the GTE predicate on the "last_checked_at" field.
+func LastCheckedAtGTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtLT applies the LT predicate on the "last_checked_at" field.
+func LastCheckedAtLT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtLTE applies the LTE predicate on the "last_checked_at" field.
+func LastCheckedAtLTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtIsNil applies the IsNil predicate on the "last_checked_at" field.
+func LastCheckedAtIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldLastCheckedAt))
+}
+
+// LastCheckedAtNotNil applies the NotNil predicate on the "last_checked_at" field.
+func LastCheckedAtNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldLastCheckedAt))
+}
+
+// CreatedByEQ applies the EQ predicate on the "created_by" field.
+func CreatedByEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
+func CreatedByNEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldCreatedBy, v))
+}
+
+// CreatedByIn applies the In predicate on the "created_by" field.
+func CreatedByIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
+func CreatedByNotIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByGT applies the GT predicate on the "created_by" field.
+func CreatedByGT(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldCreatedBy, v))
+}
+
+// CreatedByGTE applies the GTE predicate on the "created_by" field.
+func CreatedByGTE(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldCreatedBy, v))
+}
+
+// CreatedByLT applies the LT predicate on the "created_by" field.
+func CreatedByLT(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldCreatedBy, v))
+}
+
+// CreatedByLTE applies the LTE predicate on the "created_by" field.
+func CreatedByLTE(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedBy, v))
+}
+
+// TemplateIDEQ applies the EQ predicate on the "template_id" field.
+func TemplateIDEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v))
+}
+
+// TemplateIDNEQ applies the NEQ predicate on the "template_id" field.
+func TemplateIDNEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldTemplateID, v))
+}
+
+// TemplateIDIn applies the In predicate on the "template_id" field.
+func TemplateIDIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldTemplateID, vs...))
+}
+
+// TemplateIDNotIn applies the NotIn predicate on the "template_id" field.
+func TemplateIDNotIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldTemplateID, vs...))
+}
+
+// TemplateIDIsNil applies the IsNil predicate on the "template_id" field.
+func TemplateIDIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldTemplateID))
+}
+
+// TemplateIDNotNil applies the NotNil predicate on the "template_id" field.
+func TemplateIDNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldTemplateID))
+}
+
+// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field.
+func BodyOverrideModeEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field.
+func BodyOverrideModeNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field.
+func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field.
+func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field.
+func BodyOverrideModeGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field.
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field.
+func BodyOverrideModeLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field.
+func BodyOverrideModeLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field.
+func BodyOverrideModeContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field.
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field.
+func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field.
+func BodyOverrideIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldBodyOverride))
+}
+
+// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field.
+func BodyOverrideNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldBodyOverride))
+}
+
+// HasHistory applies the HasEdge predicate on the "history" edge.
+func HasHistory() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasHistoryWith applies the HasEdge predicate on the "history" edge with a given conditions (other predicates).
+func HasHistoryWith(preds ...predicate.ChannelMonitorHistory) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newHistoryStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasDailyRollups applies the HasEdge predicate on the "daily_rollups" edge.
+func HasDailyRollups() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasDailyRollupsWith applies the HasEdge predicate on the "daily_rollups" edge with a given conditions (other predicates).
+func HasDailyRollupsWith(preds ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newDailyRollupsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasRequestTemplate applies the HasEdge predicate on the "request_template" edge.
+func HasRequestTemplate() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasRequestTemplateWith applies the HasEdge predicate on the "request_template" edge with a given conditions (other predicates).
+func HasRequestTemplateWith(preds ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newRequestTemplateStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitor) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitor_create.go b/backend/ent/channelmonitor_create.go
new file mode 100644
index 00000000..2f70c300
--- /dev/null
+++ b/backend/ent/channelmonitor_create.go
@@ -0,0 +1,1610 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorCreate is the builder for creating a ChannelMonitor entity.
+type ChannelMonitorCreate struct {
+ config
+ mutation *ChannelMonitorMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *ChannelMonitorCreate) SetCreatedAt(v time.Time) *ChannelMonitorCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *ChannelMonitorCreate) SetUpdatedAt(v time.Time) *ChannelMonitorCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *ChannelMonitorCreate) SetName(v string) *ChannelMonitorCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetProvider sets the "provider" field.
+func (_c *ChannelMonitorCreate) SetProvider(v channelmonitor.Provider) *ChannelMonitorCreate {
+ _c.mutation.SetProvider(v)
+ return _c
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (_c *ChannelMonitorCreate) SetEndpoint(v string) *ChannelMonitorCreate {
+ _c.mutation.SetEndpoint(v)
+ return _c
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (_c *ChannelMonitorCreate) SetAPIKeyEncrypted(v string) *ChannelMonitorCreate {
+ _c.mutation.SetAPIKeyEncrypted(v)
+ return _c
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (_c *ChannelMonitorCreate) SetPrimaryModel(v string) *ChannelMonitorCreate {
+ _c.mutation.SetPrimaryModel(v)
+ return _c
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (_c *ChannelMonitorCreate) SetExtraModels(v []string) *ChannelMonitorCreate {
+ _c.mutation.SetExtraModels(v)
+ return _c
+}
+
+// SetGroupName sets the "group_name" field.
+func (_c *ChannelMonitorCreate) SetGroupName(v string) *ChannelMonitorCreate {
+ _c.mutation.SetGroupName(v)
+ return _c
+}
+
+// SetNillableGroupName sets the "group_name" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableGroupName(v *string) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetGroupName(*v)
+ }
+ return _c
+}
+
+// SetEnabled sets the "enabled" field.
+func (_c *ChannelMonitorCreate) SetEnabled(v bool) *ChannelMonitorCreate {
+ _c.mutation.SetEnabled(v)
+ return _c
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableEnabled(v *bool) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetEnabled(*v)
+ }
+ return _c
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (_c *ChannelMonitorCreate) SetIntervalSeconds(v int) *ChannelMonitorCreate {
+ _c.mutation.SetIntervalSeconds(v)
+ return _c
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (_c *ChannelMonitorCreate) SetLastCheckedAt(v time.Time) *ChannelMonitorCreate {
+ _c.mutation.SetLastCheckedAt(v)
+ return _c
+}
+
+// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetLastCheckedAt(*v)
+ }
+ return _c
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_c *ChannelMonitorCreate) SetCreatedBy(v int64) *ChannelMonitorCreate {
+ _c.mutation.SetCreatedBy(v)
+ return _c
+}
+
+// SetTemplateID sets the "template_id" field.
+func (_c *ChannelMonitorCreate) SetTemplateID(v int64) *ChannelMonitorCreate {
+ _c.mutation.SetTemplateID(v)
+ return _c
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableTemplateID(v *int64) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetTemplateID(*v)
+ }
+ return _c
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_c *ChannelMonitorCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorCreate {
+ _c.mutation.SetExtraHeaders(v)
+ return _c
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_c *ChannelMonitorCreate) SetBodyOverrideMode(v string) *ChannelMonitorCreate {
+ _c.mutation.SetBodyOverrideMode(v)
+ return _c
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetBodyOverrideMode(*v)
+ }
+ return _c
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_c *ChannelMonitorCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorCreate {
+ _c.mutation.SetBodyOverride(v)
+ return _c
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (_c *ChannelMonitorCreate) AddHistoryIDs(ids ...int64) *ChannelMonitorCreate {
+ _c.mutation.AddHistoryIDs(ids...)
+ return _c
+}
+
+// AddHistory adds the "history" edges to the ChannelMonitorHistory entity.
+func (_c *ChannelMonitorCreate) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddHistoryIDs(ids...)
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_c *ChannelMonitorCreate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorCreate {
+ _c.mutation.AddDailyRollupIDs(ids...)
+ return _c
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_c *ChannelMonitorCreate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddDailyRollupIDs(ids...)
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_c *ChannelMonitorCreate) SetRequestTemplateID(id int64) *ChannelMonitorCreate {
+ _c.mutation.SetRequestTemplateID(id)
+ return _c
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorCreate {
+ if id != nil {
+ _c = _c.SetRequestTemplateID(*id)
+ }
+ return _c
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_c *ChannelMonitorCreate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorCreate {
+ return _c.SetRequestTemplateID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorMutation object of the builder.
+func (_c *ChannelMonitorCreate) Mutation() *ChannelMonitorMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitor in the database.
+func (_c *ChannelMonitorCreate) Save(ctx context.Context) (*ChannelMonitor, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorCreate) SaveX(ctx context.Context) *ChannelMonitor {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := channelmonitor.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := channelmonitor.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.ExtraModels(); !ok {
+ v := channelmonitor.DefaultExtraModels
+ _c.mutation.SetExtraModels(v)
+ }
+ if _, ok := _c.mutation.GroupName(); !ok {
+ v := channelmonitor.DefaultGroupName
+ _c.mutation.SetGroupName(v)
+ }
+ if _, ok := _c.mutation.Enabled(); !ok {
+ v := channelmonitor.DefaultEnabled
+ _c.mutation.SetEnabled(v)
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ v := channelmonitor.DefaultExtraHeaders
+ _c.mutation.SetExtraHeaders(v)
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ v := channelmonitor.DefaultBodyOverrideMode
+ _c.mutation.SetBodyOverrideMode(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitor.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitor.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitor.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := channelmonitor.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Provider(); !ok {
+ return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitor.provider"`)}
+ }
+ if v, ok := _c.mutation.Provider(); ok {
+ if err := channelmonitor.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Endpoint(); !ok {
+ return &ValidationError{Name: "endpoint", err: errors.New(`ent: missing required field "ChannelMonitor.endpoint"`)}
+ }
+ if v, ok := _c.mutation.Endpoint(); ok {
+ if err := channelmonitor.EndpointValidator(v); err != nil {
+ return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.APIKeyEncrypted(); !ok {
+ return &ValidationError{Name: "api_key_encrypted", err: errors.New(`ent: missing required field "ChannelMonitor.api_key_encrypted"`)}
+ }
+ if v, ok := _c.mutation.APIKeyEncrypted(); ok {
+ if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil {
+ return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.PrimaryModel(); !ok {
+ return &ValidationError{Name: "primary_model", err: errors.New(`ent: missing required field "ChannelMonitor.primary_model"`)}
+ }
+ if v, ok := _c.mutation.PrimaryModel(); ok {
+ if err := channelmonitor.PrimaryModelValidator(v); err != nil {
+ return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExtraModels(); !ok {
+ return &ValidationError{Name: "extra_models", err: errors.New(`ent: missing required field "ChannelMonitor.extra_models"`)}
+ }
+ if v, ok := _c.mutation.GroupName(); ok {
+ if err := channelmonitor.GroupNameValidator(v); err != nil {
+ return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Enabled(); !ok {
+ return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "ChannelMonitor.enabled"`)}
+ }
+ if _, ok := _c.mutation.IntervalSeconds(); !ok {
+ return &ValidationError{Name: "interval_seconds", err: errors.New(`ent: missing required field "ChannelMonitor.interval_seconds"`)}
+ }
+ if v, ok := _c.mutation.IntervalSeconds(); ok {
+ if err := channelmonitor.IntervalSecondsValidator(v); err != nil {
+ return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CreatedBy(); !ok {
+ return &ValidationError{Name: "created_by", err: errors.New(`ent: missing required field "ChannelMonitor.created_by"`)}
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitor.extra_headers"`)}
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitor.body_override_mode"`)}
+ }
+ if v, ok := _c.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorCreate) sqlSave(ctx context.Context) (*ChannelMonitor, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitor{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitor.Table, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(channelmonitor.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Provider(); ok {
+ _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
+ _node.Provider = value
+ }
+ if value, ok := _c.mutation.Endpoint(); ok {
+ _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
+ _node.Endpoint = value
+ }
+ if value, ok := _c.mutation.APIKeyEncrypted(); ok {
+ _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value)
+ _node.APIKeyEncrypted = value
+ }
+ if value, ok := _c.mutation.PrimaryModel(); ok {
+ _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value)
+ _node.PrimaryModel = value
+ }
+ if value, ok := _c.mutation.ExtraModels(); ok {
+ _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value)
+ _node.ExtraModels = value
+ }
+ if value, ok := _c.mutation.GroupName(); ok {
+ _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value)
+ _node.GroupName = value
+ }
+ if value, ok := _c.mutation.Enabled(); ok {
+ _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value)
+ _node.Enabled = value
+ }
+ if value, ok := _c.mutation.IntervalSeconds(); ok {
+ _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ _node.IntervalSeconds = value
+ }
+ if value, ok := _c.mutation.LastCheckedAt(); ok {
+ _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value)
+ _node.LastCheckedAt = &value
+ }
+ if value, ok := _c.mutation.CreatedBy(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ _node.CreatedBy = value
+ }
+ if value, ok := _c.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ _node.ExtraHeaders = value
+ }
+ if value, ok := _c.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ _node.BodyOverrideMode = value
+ }
+ if value, ok := _c.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ _node.BodyOverride = value
+ }
+ if nodes := _c.mutation.HistoryIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TemplateID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitor.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorCreate) OnConflictColumns(columns ...string) *ChannelMonitorUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitor node.
+ ChannelMonitorUpsertOne struct {
+ create *ChannelMonitorCreate
+ }
+
+ // ChannelMonitorUpsert is the "OnConflict" setter.
+ ChannelMonitorUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateUpdatedAt() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldUpdatedAt)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorUpsert) SetName(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateName() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldName)
+ return u
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorUpsert) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldProvider, v)
+ return u
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateProvider() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldProvider)
+ return u
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (u *ChannelMonitorUpsert) SetEndpoint(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldEndpoint, v)
+ return u
+}
+
+// UpdateEndpoint sets the "endpoint" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateEndpoint() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldEndpoint)
+ return u
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (u *ChannelMonitorUpsert) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldAPIKeyEncrypted, v)
+ return u
+}
+
+// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateAPIKeyEncrypted() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldAPIKeyEncrypted)
+ return u
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (u *ChannelMonitorUpsert) SetPrimaryModel(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldPrimaryModel, v)
+ return u
+}
+
+// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdatePrimaryModel() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldPrimaryModel)
+ return u
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (u *ChannelMonitorUpsert) SetExtraModels(v []string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldExtraModels, v)
+ return u
+}
+
+// UpdateExtraModels sets the "extra_models" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateExtraModels() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldExtraModels)
+ return u
+}
+
+// SetGroupName sets the "group_name" field.
+func (u *ChannelMonitorUpsert) SetGroupName(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldGroupName, v)
+ return u
+}
+
+// UpdateGroupName sets the "group_name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateGroupName() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldGroupName)
+ return u
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (u *ChannelMonitorUpsert) ClearGroupName() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldGroupName)
+ return u
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *ChannelMonitorUpsert) SetEnabled(v bool) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldEnabled, v)
+ return u
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateEnabled() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldEnabled)
+ return u
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (u *ChannelMonitorUpsert) SetIntervalSeconds(v int) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldIntervalSeconds, v)
+ return u
+}
+
+// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateIntervalSeconds() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldIntervalSeconds)
+ return u
+}
+
+// AddIntervalSeconds adds v to the "interval_seconds" field.
+func (u *ChannelMonitorUpsert) AddIntervalSeconds(v int) *ChannelMonitorUpsert {
+ u.Add(channelmonitor.FieldIntervalSeconds, v)
+ return u
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (u *ChannelMonitorUpsert) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldLastCheckedAt, v)
+ return u
+}
+
+// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateLastCheckedAt() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldLastCheckedAt)
+ return u
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (u *ChannelMonitorUpsert) ClearLastCheckedAt() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldLastCheckedAt)
+ return u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *ChannelMonitorUpsert) SetCreatedBy(v int64) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldCreatedBy, v)
+ return u
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateCreatedBy() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldCreatedBy)
+ return u
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *ChannelMonitorUpsert) AddCreatedBy(v int64) *ChannelMonitorUpsert {
+ u.Add(channelmonitor.FieldCreatedBy, v)
+ return u
+}
+
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsert) SetTemplateID(v int64) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldTemplateID, v)
+ return u
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateTemplateID() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldTemplateID)
+ return u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsert) ClearTemplateID() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldTemplateID)
+ return u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldExtraHeaders, v)
+ return u
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateExtraHeaders() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldExtraHeaders)
+ return u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsert) SetBodyOverrideMode(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldBodyOverrideMode, v)
+ return u
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateBodyOverrideMode() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldBodyOverrideMode)
+ return u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldBodyOverride, v)
+ return u
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateBodyOverride() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldBodyOverride)
+ return u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsert) ClearBodyOverride() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldBodyOverride)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertOne) UpdateNewValues() *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitor.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertOne) Ignore() *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorUpsertOne) DoNothing() *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorUpsertOne) Update(set func(*ChannelMonitorUpsert)) *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateUpdatedAt() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorUpsertOne) SetName(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateName() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorUpsertOne) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateProvider() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (u *ChannelMonitorUpsertOne) SetEndpoint(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEndpoint(v)
+ })
+}
+
+// UpdateEndpoint sets the "endpoint" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateEndpoint() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEndpoint()
+ })
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (u *ChannelMonitorUpsertOne) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetAPIKeyEncrypted(v)
+ })
+}
+
+// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateAPIKeyEncrypted() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateAPIKeyEncrypted()
+ })
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (u *ChannelMonitorUpsertOne) SetPrimaryModel(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetPrimaryModel(v)
+ })
+}
+
+// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdatePrimaryModel() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdatePrimaryModel()
+ })
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (u *ChannelMonitorUpsertOne) SetExtraModels(v []string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraModels(v)
+ })
+}
+
+// UpdateExtraModels sets the "extra_models" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateExtraModels() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraModels()
+ })
+}
+
+// SetGroupName sets the "group_name" field.
+func (u *ChannelMonitorUpsertOne) SetGroupName(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetGroupName(v)
+ })
+}
+
+// UpdateGroupName sets the "group_name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateGroupName() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateGroupName()
+ })
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (u *ChannelMonitorUpsertOne) ClearGroupName() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearGroupName()
+ })
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *ChannelMonitorUpsertOne) SetEnabled(v bool) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEnabled(v)
+ })
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateEnabled() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEnabled()
+ })
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (u *ChannelMonitorUpsertOne) SetIntervalSeconds(v int) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetIntervalSeconds(v)
+ })
+}
+
+// AddIntervalSeconds adds v to the "interval_seconds" field.
+func (u *ChannelMonitorUpsertOne) AddIntervalSeconds(v int) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddIntervalSeconds(v)
+ })
+}
+
+// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateIntervalSeconds() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateIntervalSeconds()
+ })
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (u *ChannelMonitorUpsertOne) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetLastCheckedAt(v)
+ })
+}
+
+// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateLastCheckedAt() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateLastCheckedAt()
+ })
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (u *ChannelMonitorUpsertOne) ClearLastCheckedAt() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearLastCheckedAt()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *ChannelMonitorUpsertOne) SetCreatedBy(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *ChannelMonitorUpsertOne) AddCreatedBy(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateCreatedBy() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsertOne) SetTemplateID(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetTemplateID(v)
+ })
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateTemplateID() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateTemplateID()
+ })
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsertOne) ClearTemplateID() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearTemplateID()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateExtraHeaders() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateBodyOverride() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsertOne) ClearBodyOverride() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorCreateBulk is the builder for creating many ChannelMonitor entities in bulk.
+type ChannelMonitorCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitor entities in the database.
+func (_c *ChannelMonitorCreateBulk) Save(ctx context.Context) ([]*ChannelMonitor, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitor, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorCreateBulk) SaveX(ctx context.Context) []*ChannelMonitor {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitor.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitor nodes.
+type ChannelMonitorUpsertBulk struct {
+ create *ChannelMonitorCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertBulk) UpdateNewValues() *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitor.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertBulk) Ignore() *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorUpsertBulk) DoNothing() *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorUpsertBulk) Update(set func(*ChannelMonitorUpsert)) *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateUpdatedAt() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorUpsertBulk) SetName(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateName() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorUpsertBulk) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateProvider() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (u *ChannelMonitorUpsertBulk) SetEndpoint(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEndpoint(v)
+ })
+}
+
+// UpdateEndpoint sets the "endpoint" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateEndpoint() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEndpoint()
+ })
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (u *ChannelMonitorUpsertBulk) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetAPIKeyEncrypted(v)
+ })
+}
+
+// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateAPIKeyEncrypted() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateAPIKeyEncrypted()
+ })
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (u *ChannelMonitorUpsertBulk) SetPrimaryModel(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetPrimaryModel(v)
+ })
+}
+
+// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdatePrimaryModel() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdatePrimaryModel()
+ })
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (u *ChannelMonitorUpsertBulk) SetExtraModels(v []string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraModels(v)
+ })
+}
+
+// UpdateExtraModels sets the "extra_models" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateExtraModels() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraModels()
+ })
+}
+
+// SetGroupName sets the "group_name" field.
+func (u *ChannelMonitorUpsertBulk) SetGroupName(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetGroupName(v)
+ })
+}
+
+// UpdateGroupName sets the "group_name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateGroupName() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateGroupName()
+ })
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (u *ChannelMonitorUpsertBulk) ClearGroupName() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearGroupName()
+ })
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *ChannelMonitorUpsertBulk) SetEnabled(v bool) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEnabled(v)
+ })
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateEnabled() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEnabled()
+ })
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (u *ChannelMonitorUpsertBulk) SetIntervalSeconds(v int) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetIntervalSeconds(v)
+ })
+}
+
+// AddIntervalSeconds adds v to the "interval_seconds" field.
+func (u *ChannelMonitorUpsertBulk) AddIntervalSeconds(v int) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddIntervalSeconds(v)
+ })
+}
+
+// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateIntervalSeconds() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateIntervalSeconds()
+ })
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (u *ChannelMonitorUpsertBulk) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetLastCheckedAt(v)
+ })
+}
+
+// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateLastCheckedAt() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateLastCheckedAt()
+ })
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (u *ChannelMonitorUpsertBulk) ClearLastCheckedAt() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearLastCheckedAt()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *ChannelMonitorUpsertBulk) SetCreatedBy(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *ChannelMonitorUpsertBulk) AddCreatedBy(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateCreatedBy() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsertBulk) SetTemplateID(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetTemplateID(v)
+ })
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateTemplateID() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateTemplateID()
+ })
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsertBulk) ClearTemplateID() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearTemplateID()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateExtraHeaders() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverride() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsertBulk) ClearBodyOverride() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitor_delete.go b/backend/ent/channelmonitor_delete.go
new file mode 100644
index 00000000..500dbb48
--- /dev/null
+++ b/backend/ent/channelmonitor_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDelete is the builder for deleting a ChannelMonitor entity.
+type ChannelMonitorDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDelete builder.
+func (_d *ChannelMonitorDelete) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitor.Table, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorDeleteOne is the builder for deleting a single ChannelMonitor entity.
+type ChannelMonitorDeleteOne struct {
+ _d *ChannelMonitorDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorDelete builder.
+func (_d *ChannelMonitorDeleteOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitor.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitor_query.go b/backend/ent/channelmonitor_query.go
new file mode 100644
index 00000000..b6722e78
--- /dev/null
+++ b/backend/ent/channelmonitor_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorQuery is the builder for querying ChannelMonitor entities.
+type ChannelMonitorQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitor.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitor
+ withHistory *ChannelMonitorHistoryQuery
+ withDailyRollups *ChannelMonitorDailyRollupQuery
+ withRequestTemplate *ChannelMonitorRequestTemplateQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorQuery builder.
+func (_q *ChannelMonitorQuery) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorQuery) Limit(limit int) *ChannelMonitorQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorQuery) Offset(offset int) *ChannelMonitorQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorQuery) Unique(unique bool) *ChannelMonitorQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorQuery) Order(o ...channelmonitor.OrderOption) *ChannelMonitorQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryHistory chains the current query on the "history" edge.
+func (_q *ChannelMonitorQuery) QueryHistory() *ChannelMonitorHistoryQuery {
+ query := (&ChannelMonitorHistoryClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitorhistory.Table, channelmonitorhistory.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.HistoryTable, channelmonitor.HistoryColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryDailyRollups chains the current query on the "daily_rollups" edge.
+func (_q *ChannelMonitorQuery) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryRequestTemplate chains the current query on the "request_template" edge.
+func (_q *ChannelMonitorQuery) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitor entity from the query.
+// Returns a *NotFoundError when no ChannelMonitor was found.
+func (_q *ChannelMonitorQuery) First(ctx context.Context) (*ChannelMonitor, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitor.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) FirstX(ctx context.Context) *ChannelMonitor {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitor ID from the query.
+// Returns a *NotFoundError when no ChannelMonitor ID was found.
+func (_q *ChannelMonitorQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitor.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitor entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitor entity is found.
+// Returns a *NotFoundError when no ChannelMonitor entities are found.
+func (_q *ChannelMonitorQuery) Only(ctx context.Context) (*ChannelMonitor, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitor.Label}
+ default:
+ return nil, &NotSingularError{channelmonitor.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) OnlyX(ctx context.Context) *ChannelMonitor {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitor ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitor ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitor.Label}
+ default:
+ err = &NotSingularError{channelmonitor.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitors.
+func (_q *ChannelMonitorQuery) All(ctx context.Context) ([]*ChannelMonitor, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitor, *ChannelMonitorQuery]()
+ return withInterceptors[[]*ChannelMonitor](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) AllX(ctx context.Context) []*ChannelMonitor {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitor IDs.
+func (_q *ChannelMonitorQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitor.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorQuery) Clone() *ChannelMonitorQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitor.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitor{}, _q.predicates...),
+ withHistory: _q.withHistory.Clone(),
+ withDailyRollups: _q.withDailyRollups.Clone(),
+ withRequestTemplate: _q.withRequestTemplate.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithHistory tells the query-builder to eager-load the nodes that are connected to
+// the "history" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithHistory(opts ...func(*ChannelMonitorHistoryQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorHistoryClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withHistory = query
+ return _q
+}
+
+// WithDailyRollups tells the query-builder to eager-load the nodes that are connected to
+// the "daily_rollups" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithDailyRollups(opts ...func(*ChannelMonitorDailyRollupQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withDailyRollups = query
+ return _q
+}
+
+// WithRequestTemplate tells the query-builder to eager-load the nodes that are connected to
+// the "request_template" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithRequestTemplate(opts ...func(*ChannelMonitorRequestTemplateQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withRequestTemplate = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitor.Query().
+// GroupBy(channelmonitor.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorQuery) GroupBy(field string, fields ...string) *ChannelMonitorGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitor.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.ChannelMonitor.Query().
+// Select(channelmonitor.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorQuery) Select(fields ...string) *ChannelMonitorSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorSelect{ChannelMonitorQuery: _q}
+ sbuild.label = channelmonitor.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorSelect configured with the given aggregations.
+func (_q *ChannelMonitorQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitor.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitor, error) {
+ var (
+ nodes = []*ChannelMonitor{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withHistory != nil,
+ _q.withDailyRollups != nil,
+ _q.withRequestTemplate != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitor).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitor{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withHistory; query != nil {
+ if err := _q.loadHistory(ctx, query, nodes,
+ func(n *ChannelMonitor) { n.Edges.History = []*ChannelMonitorHistory{} },
+ func(n *ChannelMonitor, e *ChannelMonitorHistory) { n.Edges.History = append(n.Edges.History, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withDailyRollups; query != nil {
+ if err := _q.loadDailyRollups(ctx, query, nodes,
+ func(n *ChannelMonitor) { n.Edges.DailyRollups = []*ChannelMonitorDailyRollup{} },
+ func(n *ChannelMonitor, e *ChannelMonitorDailyRollup) {
+ n.Edges.DailyRollups = append(n.Edges.DailyRollups, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withRequestTemplate; query != nil {
+ if err := _q.loadRequestTemplate(ctx, query, nodes, nil,
+ func(n *ChannelMonitor, e *ChannelMonitorRequestTemplate) { n.Edges.RequestTemplate = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorQuery) loadHistory(ctx context.Context, query *ChannelMonitorHistoryQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorHistory)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitor)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitorhistory.FieldMonitorID)
+ }
+ query.Where(predicate.ChannelMonitorHistory(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitor.HistoryColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.MonitorID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *ChannelMonitorQuery) loadDailyRollups(ctx context.Context, query *ChannelMonitorDailyRollupQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorDailyRollup)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitor)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitordailyrollup.FieldMonitorID)
+ }
+ query.Where(predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitor.DailyRollupsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.MonitorID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *ChannelMonitorQuery) loadRequestTemplate(ctx context.Context, query *ChannelMonitorRequestTemplateQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorRequestTemplate)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitor)
+ for i := range nodes {
+ if nodes[i].TemplateID == nil {
+ continue
+ }
+ fk := *nodes[i].TemplateID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitorrequesttemplate.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "template_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitor.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitor.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withRequestTemplate != nil {
+ _spec.Node.AddColumnOnce(channelmonitor.FieldTemplateID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitor.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitor.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorGroupBy is the group-by builder for ChannelMonitor entities.
+type ChannelMonitorGroupBy struct {
+ selector
+ build *ChannelMonitorQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorQuery, *ChannelMonitorGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorSelect is the builder for selecting fields of ChannelMonitor entities.
+type ChannelMonitorSelect struct {
+ *ChannelMonitorQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorQuery, *ChannelMonitorSelect](ctx, _s.ChannelMonitorQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorSelect) sqlScan(ctx context.Context, root *ChannelMonitorQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitor_update.go b/backend/ent/channelmonitor_update.go
new file mode 100644
index 00000000..4bbcd564
--- /dev/null
+++ b/backend/ent/channelmonitor_update.go
@@ -0,0 +1,1328 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/dialect/sql/sqljson"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorUpdate is the builder for updating ChannelMonitor entities.
+type ChannelMonitorUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorUpdate builder.
+func (_u *ChannelMonitorUpdate) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorUpdate) SetName(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableName(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorUpdate) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpdate {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableProvider(v *channelmonitor.Provider) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (_u *ChannelMonitorUpdate) SetEndpoint(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetEndpoint(v)
+ return _u
+}
+
+// SetNillableEndpoint sets the "endpoint" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableEndpoint(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetEndpoint(*v)
+ }
+ return _u
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (_u *ChannelMonitorUpdate) SetAPIKeyEncrypted(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetAPIKeyEncrypted(v)
+ return _u
+}
+
+// SetNillableAPIKeyEncrypted sets the "api_key_encrypted" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableAPIKeyEncrypted(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetAPIKeyEncrypted(*v)
+ }
+ return _u
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (_u *ChannelMonitorUpdate) SetPrimaryModel(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetPrimaryModel(v)
+ return _u
+}
+
+// SetNillablePrimaryModel sets the "primary_model" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillablePrimaryModel(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetPrimaryModel(*v)
+ }
+ return _u
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (_u *ChannelMonitorUpdate) SetExtraModels(v []string) *ChannelMonitorUpdate {
+ _u.mutation.SetExtraModels(v)
+ return _u
+}
+
+// AppendExtraModels appends value to the "extra_models" field.
+func (_u *ChannelMonitorUpdate) AppendExtraModels(v []string) *ChannelMonitorUpdate {
+ _u.mutation.AppendExtraModels(v)
+ return _u
+}
+
+// SetGroupName sets the "group_name" field.
+func (_u *ChannelMonitorUpdate) SetGroupName(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetGroupName(v)
+ return _u
+}
+
+// SetNillableGroupName sets the "group_name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableGroupName(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetGroupName(*v)
+ }
+ return _u
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (_u *ChannelMonitorUpdate) ClearGroupName() *ChannelMonitorUpdate {
+ _u.mutation.ClearGroupName()
+ return _u
+}
+
+// SetEnabled sets the "enabled" field.
+func (_u *ChannelMonitorUpdate) SetEnabled(v bool) *ChannelMonitorUpdate {
+ _u.mutation.SetEnabled(v)
+ return _u
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableEnabled(v *bool) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetEnabled(*v)
+ }
+ return _u
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (_u *ChannelMonitorUpdate) SetIntervalSeconds(v int) *ChannelMonitorUpdate {
+ _u.mutation.ResetIntervalSeconds()
+ _u.mutation.SetIntervalSeconds(v)
+ return _u
+}
+
+// SetNillableIntervalSeconds sets the "interval_seconds" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableIntervalSeconds(v *int) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetIntervalSeconds(*v)
+ }
+ return _u
+}
+
+// AddIntervalSeconds adds value to the "interval_seconds" field.
+func (_u *ChannelMonitorUpdate) AddIntervalSeconds(v int) *ChannelMonitorUpdate {
+ _u.mutation.AddIntervalSeconds(v)
+ return _u
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (_u *ChannelMonitorUpdate) SetLastCheckedAt(v time.Time) *ChannelMonitorUpdate {
+ _u.mutation.SetLastCheckedAt(v)
+ return _u
+}
+
+// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetLastCheckedAt(*v)
+ }
+ return _u
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (_u *ChannelMonitorUpdate) ClearLastCheckedAt() *ChannelMonitorUpdate {
+ _u.mutation.ClearLastCheckedAt()
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *ChannelMonitorUpdate) SetCreatedBy(v int64) *ChannelMonitorUpdate {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableCreatedBy(v *int64) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *ChannelMonitorUpdate) AddCreatedBy(v int64) *ChannelMonitorUpdate {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// SetTemplateID sets the "template_id" field.
+func (_u *ChannelMonitorUpdate) SetTemplateID(v int64) *ChannelMonitorUpdate {
+ _u.mutation.SetTemplateID(v)
+ return _u
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableTemplateID(v *int64) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetTemplateID(*v)
+ }
+ return _u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (_u *ChannelMonitorUpdate) ClearTemplateID() *ChannelMonitorUpdate {
+ _u.mutation.ClearTemplateID()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdate {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorUpdate) SetBodyOverrideMode(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdate {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorUpdate) ClearBodyOverride() *ChannelMonitorUpdate {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (_u *ChannelMonitorUpdate) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.AddHistoryIDs(ids...)
+ return _u
+}
+
+// AddHistory adds the "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdate) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddHistoryIDs(ids...)
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_u *ChannelMonitorUpdate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.AddDailyRollupIDs(ids...)
+ return _u
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddDailyRollupIDs(ids...)
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_u *ChannelMonitorUpdate) SetRequestTemplateID(id int64) *ChannelMonitorUpdate {
+ _u.mutation.SetRequestTemplateID(id)
+ return _u
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdate {
+ if id != nil {
+ _u = _u.SetRequestTemplateID(*id)
+ }
+ return _u
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdate {
+ return _u.SetRequestTemplateID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorMutation object of the builder.
+func (_u *ChannelMonitorUpdate) Mutation() *ChannelMonitorMutation {
+ return _u.mutation
+}
+
+// ClearHistory clears all "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdate) ClearHistory() *ChannelMonitorUpdate {
+ _u.mutation.ClearHistory()
+ return _u
+}
+
+// RemoveHistoryIDs removes the "history" edge to ChannelMonitorHistory entities by IDs.
+func (_u *ChannelMonitorUpdate) RemoveHistoryIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.RemoveHistoryIDs(ids...)
+ return _u
+}
+
+// RemoveHistory removes "history" edges to ChannelMonitorHistory entities.
+func (_u *ChannelMonitorUpdate) RemoveHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveHistoryIDs(ids...)
+}
+
+// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdate) ClearDailyRollups() *ChannelMonitorUpdate {
+ _u.mutation.ClearDailyRollups()
+ return _u
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs.
+func (_u *ChannelMonitorUpdate) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.RemoveDailyRollupIDs(ids...)
+ return _u
+}
+
+// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities.
+func (_u *ChannelMonitorUpdate) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveDailyRollupIDs(ids...)
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdate) ClearRequestTemplate() *ChannelMonitorUpdate {
+ _u.mutation.ClearRequestTemplate()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitor.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorUpdate) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitor.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitor.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Endpoint(); ok {
+ if err := channelmonitor.EndpointValidator(v); err != nil {
+ return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.APIKeyEncrypted(); ok {
+ if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil {
+ return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PrimaryModel(); ok {
+ if err := channelmonitor.PrimaryModelValidator(v); err != nil {
+ return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.GroupName(); ok {
+ if err := channelmonitor.GroupNameValidator(v); err != nil {
+ return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.IntervalSeconds(); ok {
+ if err := channelmonitor.IntervalSecondsValidator(v); err != nil {
+ return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitor.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Endpoint(); ok {
+ _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.APIKeyEncrypted(); ok {
+ _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PrimaryModel(); ok {
+ _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ExtraModels(); ok {
+ _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.AppendedExtraModels(); ok {
+ _spec.AddModifier(func(u *sql.UpdateBuilder) {
+ sqljson.Append(u, channelmonitor.FieldExtraModels, value)
+ })
+ }
+ if value, ok := _u.mutation.GroupName(); ok {
+ _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value)
+ }
+ if _u.mutation.GroupNameCleared() {
+ _spec.ClearField(channelmonitor.FieldGroupName, field.TypeString)
+ }
+ if value, ok := _u.mutation.Enabled(); ok {
+ _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.IntervalSeconds(); ok {
+ _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedIntervalSeconds(); ok {
+ _spec.AddField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.LastCheckedAt(); ok {
+ _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastCheckedAtCleared() {
+ _spec.ClearField(channelmonitor.FieldLastCheckedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedHistoryIDs(); len(nodes) > 0 && !_u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.HistoryIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.RequestTemplateCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitor.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorUpdateOne is the builder for updating a single ChannelMonitor entity.
+type ChannelMonitorUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorUpdateOne) SetName(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableName(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorUpdateOne) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpdateOne {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableProvider(v *channelmonitor.Provider) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (_u *ChannelMonitorUpdateOne) SetEndpoint(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetEndpoint(v)
+ return _u
+}
+
+// SetNillableEndpoint sets the "endpoint" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableEndpoint(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetEndpoint(*v)
+ }
+ return _u
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (_u *ChannelMonitorUpdateOne) SetAPIKeyEncrypted(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetAPIKeyEncrypted(v)
+ return _u
+}
+
+// SetNillableAPIKeyEncrypted sets the "api_key_encrypted" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableAPIKeyEncrypted(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetAPIKeyEncrypted(*v)
+ }
+ return _u
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (_u *ChannelMonitorUpdateOne) SetPrimaryModel(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetPrimaryModel(v)
+ return _u
+}
+
+// SetNillablePrimaryModel sets the "primary_model" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillablePrimaryModel(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetPrimaryModel(*v)
+ }
+ return _u
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (_u *ChannelMonitorUpdateOne) SetExtraModels(v []string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetExtraModels(v)
+ return _u
+}
+
+// AppendExtraModels appends value to the "extra_models" field.
+func (_u *ChannelMonitorUpdateOne) AppendExtraModels(v []string) *ChannelMonitorUpdateOne {
+ _u.mutation.AppendExtraModels(v)
+ return _u
+}
+
+// SetGroupName sets the "group_name" field.
+func (_u *ChannelMonitorUpdateOne) SetGroupName(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetGroupName(v)
+ return _u
+}
+
+// SetNillableGroupName sets the "group_name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableGroupName(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetGroupName(*v)
+ }
+ return _u
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (_u *ChannelMonitorUpdateOne) ClearGroupName() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearGroupName()
+ return _u
+}
+
+// SetEnabled sets the "enabled" field.
+func (_u *ChannelMonitorUpdateOne) SetEnabled(v bool) *ChannelMonitorUpdateOne {
+ _u.mutation.SetEnabled(v)
+ return _u
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableEnabled(v *bool) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetEnabled(*v)
+ }
+ return _u
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (_u *ChannelMonitorUpdateOne) SetIntervalSeconds(v int) *ChannelMonitorUpdateOne {
+ _u.mutation.ResetIntervalSeconds()
+ _u.mutation.SetIntervalSeconds(v)
+ return _u
+}
+
+// SetNillableIntervalSeconds sets the "interval_seconds" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableIntervalSeconds(v *int) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetIntervalSeconds(*v)
+ }
+ return _u
+}
+
+// AddIntervalSeconds adds value to the "interval_seconds" field.
+func (_u *ChannelMonitorUpdateOne) AddIntervalSeconds(v int) *ChannelMonitorUpdateOne {
+ _u.mutation.AddIntervalSeconds(v)
+ return _u
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (_u *ChannelMonitorUpdateOne) SetLastCheckedAt(v time.Time) *ChannelMonitorUpdateOne {
+ _u.mutation.SetLastCheckedAt(v)
+ return _u
+}
+
+// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetLastCheckedAt(*v)
+ }
+ return _u
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (_u *ChannelMonitorUpdateOne) ClearLastCheckedAt() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearLastCheckedAt()
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *ChannelMonitorUpdateOne) SetCreatedBy(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableCreatedBy(v *int64) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *ChannelMonitorUpdateOne) AddCreatedBy(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// SetTemplateID sets the "template_id" field.
+func (_u *ChannelMonitorUpdateOne) SetTemplateID(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.SetTemplateID(v)
+ return _u
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableTemplateID(v *int64) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetTemplateID(*v)
+ }
+ return _u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (_u *ChannelMonitorUpdateOne) ClearTemplateID() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearTemplateID()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdateOne {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorUpdateOne) ClearBodyOverride() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (_u *ChannelMonitorUpdateOne) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddHistoryIDs(ids...)
+ return _u
+}
+
+// AddHistory adds the "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdateOne) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddHistoryIDs(ids...)
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_u *ChannelMonitorUpdateOne) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddDailyRollupIDs(ids...)
+ return _u
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdateOne) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddDailyRollupIDs(ids...)
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_u *ChannelMonitorUpdateOne) SetRequestTemplateID(id int64) *ChannelMonitorUpdateOne {
+ _u.mutation.SetRequestTemplateID(id)
+ return _u
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdateOne {
+ if id != nil {
+ _u = _u.SetRequestTemplateID(*id)
+ }
+ return _u
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdateOne) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdateOne {
+ return _u.SetRequestTemplateID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorMutation object of the builder.
+func (_u *ChannelMonitorUpdateOne) Mutation() *ChannelMonitorMutation {
+ return _u.mutation
+}
+
+// ClearHistory clears all "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdateOne) ClearHistory() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearHistory()
+ return _u
+}
+
+// RemoveHistoryIDs removes the "history" edge to ChannelMonitorHistory entities by IDs.
+func (_u *ChannelMonitorUpdateOne) RemoveHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.RemoveHistoryIDs(ids...)
+ return _u
+}
+
+// RemoveHistory removes "history" edges to ChannelMonitorHistory entities.
+func (_u *ChannelMonitorUpdateOne) RemoveHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveHistoryIDs(ids...)
+}
+
+// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdateOne) ClearDailyRollups() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearDailyRollups()
+ return _u
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs.
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.RemoveDailyRollupIDs(ids...)
+ return _u
+}
+
+// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities.
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveDailyRollupIDs(ids...)
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdateOne) ClearRequestTemplate() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearRequestTemplate()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorUpdate builder.
+func (_u *ChannelMonitorUpdateOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorUpdateOne) Select(field string, fields ...string) *ChannelMonitorUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitor entity.
+func (_u *ChannelMonitorUpdateOne) Save(ctx context.Context) (*ChannelMonitor, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorUpdateOne) SaveX(ctx context.Context) *ChannelMonitor {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitor.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorUpdateOne) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitor.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitor.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Endpoint(); ok {
+ if err := channelmonitor.EndpointValidator(v); err != nil {
+ return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.APIKeyEncrypted(); ok {
+ if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil {
+ return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PrimaryModel(); ok {
+ if err := channelmonitor.PrimaryModelValidator(v); err != nil {
+ return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.GroupName(); ok {
+ if err := channelmonitor.GroupNameValidator(v); err != nil {
+ return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.IntervalSeconds(); ok {
+ if err := channelmonitor.IntervalSecondsValidator(v); err != nil {
+ return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitor, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitor.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitor.FieldID)
+ for _, f := range fields {
+ if !channelmonitor.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitor.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitor.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Endpoint(); ok {
+ _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.APIKeyEncrypted(); ok {
+ _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PrimaryModel(); ok {
+ _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ExtraModels(); ok {
+ _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.AppendedExtraModels(); ok {
+ _spec.AddModifier(func(u *sql.UpdateBuilder) {
+ sqljson.Append(u, channelmonitor.FieldExtraModels, value)
+ })
+ }
+ if value, ok := _u.mutation.GroupName(); ok {
+ _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value)
+ }
+ if _u.mutation.GroupNameCleared() {
+ _spec.ClearField(channelmonitor.FieldGroupName, field.TypeString)
+ }
+ if value, ok := _u.mutation.Enabled(); ok {
+ _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.IntervalSeconds(); ok {
+ _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedIntervalSeconds(); ok {
+ _spec.AddField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.LastCheckedAt(); ok {
+ _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastCheckedAtCleared() {
+ _spec.ClearField(channelmonitor.FieldLastCheckedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedHistoryIDs(); len(nodes) > 0 && !_u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.HistoryIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.RequestTemplateCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitor{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitor.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup.go
new file mode 100644
index 00000000..78a5f489
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup.go
@@ -0,0 +1,278 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+)
+
+// ChannelMonitorDailyRollup is the model entity for the ChannelMonitorDailyRollup schema.
+type ChannelMonitorDailyRollup struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // MonitorID holds the value of the "monitor_id" field.
+ MonitorID int64 `json:"monitor_id,omitempty"`
+ // Model holds the value of the "model" field.
+ Model string `json:"model,omitempty"`
+ // BucketDate holds the value of the "bucket_date" field.
+ BucketDate time.Time `json:"bucket_date,omitempty"`
+ // TotalChecks holds the value of the "total_checks" field.
+ TotalChecks int `json:"total_checks,omitempty"`
+ // OkCount holds the value of the "ok_count" field.
+ OkCount int `json:"ok_count,omitempty"`
+ // OperationalCount holds the value of the "operational_count" field.
+ OperationalCount int `json:"operational_count,omitempty"`
+ // DegradedCount holds the value of the "degraded_count" field.
+ DegradedCount int `json:"degraded_count,omitempty"`
+ // FailedCount holds the value of the "failed_count" field.
+ FailedCount int `json:"failed_count,omitempty"`
+ // ErrorCount holds the value of the "error_count" field.
+ ErrorCount int `json:"error_count,omitempty"`
+ // SumLatencyMs holds the value of the "sum_latency_ms" field.
+ SumLatencyMs int64 `json:"sum_latency_ms,omitempty"`
+ // CountLatency holds the value of the "count_latency" field.
+ CountLatency int `json:"count_latency,omitempty"`
+ // SumPingLatencyMs holds the value of the "sum_ping_latency_ms" field.
+ SumPingLatencyMs int64 `json:"sum_ping_latency_ms,omitempty"`
+ // CountPingLatency holds the value of the "count_ping_latency" field.
+ CountPingLatency int `json:"count_ping_latency,omitempty"`
+ // ComputedAt holds the value of the "computed_at" field.
+ ComputedAt time.Time `json:"computed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorDailyRollupQuery when eager-loading is set.
+ Edges ChannelMonitorDailyRollupEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorDailyRollupEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorDailyRollupEdges struct {
+ // Monitor holds the value of the monitor edge.
+ Monitor *ChannelMonitor `json:"monitor,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorOrErr returns the Monitor value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorDailyRollupEdges) MonitorOrErr() (*ChannelMonitor, error) {
+ if e.Monitor != nil {
+ return e.Monitor, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: channelmonitor.Label}
+ }
+ return nil, &NotLoadedError{edge: "monitor"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorDailyRollup) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitordailyrollup.FieldID, channelmonitordailyrollup.FieldMonitorID, channelmonitordailyrollup.FieldTotalChecks, channelmonitordailyrollup.FieldOkCount, channelmonitordailyrollup.FieldOperationalCount, channelmonitordailyrollup.FieldDegradedCount, channelmonitordailyrollup.FieldFailedCount, channelmonitordailyrollup.FieldErrorCount, channelmonitordailyrollup.FieldSumLatencyMs, channelmonitordailyrollup.FieldCountLatency, channelmonitordailyrollup.FieldSumPingLatencyMs, channelmonitordailyrollup.FieldCountPingLatency:
+ values[i] = new(sql.NullInt64)
+ case channelmonitordailyrollup.FieldModel:
+ values[i] = new(sql.NullString)
+ case channelmonitordailyrollup.FieldBucketDate, channelmonitordailyrollup.FieldComputedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorDailyRollup fields.
+func (_m *ChannelMonitorDailyRollup) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitordailyrollup.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitordailyrollup.FieldMonitorID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
+ } else if value.Valid {
+ _m.MonitorID = value.Int64
+ }
+ case channelmonitordailyrollup.FieldModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field model", values[i])
+ } else if value.Valid {
+ _m.Model = value.String
+ }
+ case channelmonitordailyrollup.FieldBucketDate:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field bucket_date", values[i])
+ } else if value.Valid {
+ _m.BucketDate = value.Time
+ }
+ case channelmonitordailyrollup.FieldTotalChecks:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field total_checks", values[i])
+ } else if value.Valid {
+ _m.TotalChecks = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldOkCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field ok_count", values[i])
+ } else if value.Valid {
+ _m.OkCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldOperationalCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field operational_count", values[i])
+ } else if value.Valid {
+ _m.OperationalCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldDegradedCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field degraded_count", values[i])
+ } else if value.Valid {
+ _m.DegradedCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldFailedCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field failed_count", values[i])
+ } else if value.Valid {
+ _m.FailedCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldErrorCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field error_count", values[i])
+ } else if value.Valid {
+ _m.ErrorCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sum_latency_ms", values[i])
+ } else if value.Valid {
+ _m.SumLatencyMs = value.Int64
+ }
+ case channelmonitordailyrollup.FieldCountLatency:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field count_latency", values[i])
+ } else if value.Valid {
+ _m.CountLatency = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sum_ping_latency_ms", values[i])
+ } else if value.Valid {
+ _m.SumPingLatencyMs = value.Int64
+ }
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field count_ping_latency", values[i])
+ } else if value.Valid {
+ _m.CountPingLatency = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldComputedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field computed_at", values[i])
+ } else if value.Valid {
+ _m.ComputedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorDailyRollup.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorDailyRollup) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitor queries the "monitor" edge of the ChannelMonitorDailyRollup entity.
+func (_m *ChannelMonitorDailyRollup) QueryMonitor() *ChannelMonitorQuery {
+ return NewChannelMonitorDailyRollupClient(_m.config).QueryMonitor(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorDailyRollup.
+// Note that you need to call ChannelMonitorDailyRollup.Unwrap() before calling this method if this ChannelMonitorDailyRollup
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorDailyRollup) Update() *ChannelMonitorDailyRollupUpdateOne {
+ return NewChannelMonitorDailyRollupClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorDailyRollup entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorDailyRollup) Unwrap() *ChannelMonitorDailyRollup {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorDailyRollup is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorDailyRollup) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorDailyRollup(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("monitor_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
+ builder.WriteString(", ")
+ builder.WriteString("model=")
+ builder.WriteString(_m.Model)
+ builder.WriteString(", ")
+ builder.WriteString("bucket_date=")
+ builder.WriteString(_m.BucketDate.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("total_checks=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotalChecks))
+ builder.WriteString(", ")
+ builder.WriteString("ok_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.OkCount))
+ builder.WriteString(", ")
+ builder.WriteString("operational_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.OperationalCount))
+ builder.WriteString(", ")
+ builder.WriteString("degraded_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.DegradedCount))
+ builder.WriteString(", ")
+ builder.WriteString("failed_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.FailedCount))
+ builder.WriteString(", ")
+ builder.WriteString("error_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ErrorCount))
+ builder.WriteString(", ")
+ builder.WriteString("sum_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SumLatencyMs))
+ builder.WriteString(", ")
+ builder.WriteString("count_latency=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CountLatency))
+ builder.WriteString(", ")
+ builder.WriteString("sum_ping_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SumPingLatencyMs))
+ builder.WriteString(", ")
+ builder.WriteString("count_ping_latency=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CountPingLatency))
+ builder.WriteString(", ")
+ builder.WriteString("computed_at=")
+ builder.WriteString(_m.ComputedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorDailyRollups is a parsable slice of ChannelMonitorDailyRollup.
+type ChannelMonitorDailyRollups []*ChannelMonitorDailyRollup
diff --git a/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
new file mode 100644
index 00000000..e7cb9307
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
@@ -0,0 +1,206 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitordailyrollup
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitordailyrollup type in the database.
+ Label = "channel_monitor_daily_rollup"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldMonitorID holds the string denoting the monitor_id field in the database.
+ FieldMonitorID = "monitor_id"
+ // FieldModel holds the string denoting the model field in the database.
+ FieldModel = "model"
+ // FieldBucketDate holds the string denoting the bucket_date field in the database.
+ FieldBucketDate = "bucket_date"
+ // FieldTotalChecks holds the string denoting the total_checks field in the database.
+ FieldTotalChecks = "total_checks"
+ // FieldOkCount holds the string denoting the ok_count field in the database.
+ FieldOkCount = "ok_count"
+ // FieldOperationalCount holds the string denoting the operational_count field in the database.
+ FieldOperationalCount = "operational_count"
+ // FieldDegradedCount holds the string denoting the degraded_count field in the database.
+ FieldDegradedCount = "degraded_count"
+ // FieldFailedCount holds the string denoting the failed_count field in the database.
+ FieldFailedCount = "failed_count"
+ // FieldErrorCount holds the string denoting the error_count field in the database.
+ FieldErrorCount = "error_count"
+ // FieldSumLatencyMs holds the string denoting the sum_latency_ms field in the database.
+ FieldSumLatencyMs = "sum_latency_ms"
+ // FieldCountLatency holds the string denoting the count_latency field in the database.
+ FieldCountLatency = "count_latency"
+ // FieldSumPingLatencyMs holds the string denoting the sum_ping_latency_ms field in the database.
+ FieldSumPingLatencyMs = "sum_ping_latency_ms"
+ // FieldCountPingLatency holds the string denoting the count_ping_latency field in the database.
+ FieldCountPingLatency = "count_ping_latency"
+ // FieldComputedAt holds the string denoting the computed_at field in the database.
+ FieldComputedAt = "computed_at"
+ // EdgeMonitor holds the string denoting the monitor edge name in mutations.
+ EdgeMonitor = "monitor"
+ // Table holds the table name of the channelmonitordailyrollup in the database.
+ Table = "channel_monitor_daily_rollups"
+ // MonitorTable is the table that holds the monitor relation/edge.
+ MonitorTable = "channel_monitor_daily_rollups"
+ // MonitorInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorInverseTable = "channel_monitors"
+ // MonitorColumn is the table column denoting the monitor relation/edge.
+ MonitorColumn = "monitor_id"
+)
+
+// Columns holds all SQL columns for channelmonitordailyrollup fields.
+var Columns = []string{
+ FieldID,
+ FieldMonitorID,
+ FieldModel,
+ FieldBucketDate,
+ FieldTotalChecks,
+ FieldOkCount,
+ FieldOperationalCount,
+ FieldDegradedCount,
+ FieldFailedCount,
+ FieldErrorCount,
+ FieldSumLatencyMs,
+ FieldCountLatency,
+ FieldSumPingLatencyMs,
+ FieldCountPingLatency,
+ FieldComputedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ ModelValidator func(string) error
+ // DefaultTotalChecks holds the default value on creation for the "total_checks" field.
+ DefaultTotalChecks int
+ // DefaultOkCount holds the default value on creation for the "ok_count" field.
+ DefaultOkCount int
+ // DefaultOperationalCount holds the default value on creation for the "operational_count" field.
+ DefaultOperationalCount int
+ // DefaultDegradedCount holds the default value on creation for the "degraded_count" field.
+ DefaultDegradedCount int
+ // DefaultFailedCount holds the default value on creation for the "failed_count" field.
+ DefaultFailedCount int
+ // DefaultErrorCount holds the default value on creation for the "error_count" field.
+ DefaultErrorCount int
+ // DefaultSumLatencyMs holds the default value on creation for the "sum_latency_ms" field.
+ DefaultSumLatencyMs int64
+ // DefaultCountLatency holds the default value on creation for the "count_latency" field.
+ DefaultCountLatency int
+ // DefaultSumPingLatencyMs holds the default value on creation for the "sum_ping_latency_ms" field.
+ DefaultSumPingLatencyMs int64
+ // DefaultCountPingLatency holds the default value on creation for the "count_ping_latency" field.
+ DefaultCountPingLatency int
+ // DefaultComputedAt holds the default value on creation for the "computed_at" field.
+ DefaultComputedAt func() time.Time
+ // UpdateDefaultComputedAt holds the default value on update for the "computed_at" field.
+ UpdateDefaultComputedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the ChannelMonitorDailyRollup queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByMonitorID orders the results by the monitor_id field.
+func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
+}
+
+// ByModel orders the results by the model field.
+func ByModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModel, opts...).ToFunc()
+}
+
+// ByBucketDate orders the results by the bucket_date field.
+func ByBucketDate(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBucketDate, opts...).ToFunc()
+}
+
+// ByTotalChecks orders the results by the total_checks field.
+func ByTotalChecks(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotalChecks, opts...).ToFunc()
+}
+
+// ByOkCount orders the results by the ok_count field.
+func ByOkCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOkCount, opts...).ToFunc()
+}
+
+// ByOperationalCount orders the results by the operational_count field.
+func ByOperationalCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOperationalCount, opts...).ToFunc()
+}
+
+// ByDegradedCount orders the results by the degraded_count field.
+func ByDegradedCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDegradedCount, opts...).ToFunc()
+}
+
+// ByFailedCount orders the results by the failed_count field.
+func ByFailedCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFailedCount, opts...).ToFunc()
+}
+
+// ByErrorCount orders the results by the error_count field.
+func ByErrorCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldErrorCount, opts...).ToFunc()
+}
+
+// BySumLatencyMs orders the results by the sum_latency_ms field.
+func BySumLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSumLatencyMs, opts...).ToFunc()
+}
+
+// ByCountLatency orders the results by the count_latency field.
+func ByCountLatency(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCountLatency, opts...).ToFunc()
+}
+
+// BySumPingLatencyMs orders the results by the sum_ping_latency_ms field.
+func BySumPingLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSumPingLatencyMs, opts...).ToFunc()
+}
+
+// ByCountPingLatency orders the results by the count_ping_latency field.
+func ByCountPingLatency(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCountPingLatency, opts...).ToFunc()
+}
+
+// ByComputedAt orders the results by the computed_at field.
+func ByComputedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldComputedAt, opts...).ToFunc()
+}
+
+// ByMonitorField orders the results by monitor field.
+func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newMonitorStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+}
diff --git a/backend/ent/channelmonitordailyrollup/where.go b/backend/ent/channelmonitordailyrollup/where.go
new file mode 100644
index 00000000..424c957e
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup/where.go
@@ -0,0 +1,729 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitordailyrollup
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldID, id))
+}
+
+// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
+func MonitorID(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// Model applies equality check predicate on the "model" field. It's identical to ModelEQ.
+func Model(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v))
+}
+
+// BucketDate applies equality check predicate on the "bucket_date" field. It's identical to BucketDateEQ.
+func BucketDate(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v))
+}
+
+// TotalChecks applies equality check predicate on the "total_checks" field. It's identical to TotalChecksEQ.
+func TotalChecks(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v))
+}
+
+// OkCount applies equality check predicate on the "ok_count" field. It's identical to OkCountEQ.
+func OkCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v))
+}
+
+// OperationalCount applies equality check predicate on the "operational_count" field. It's identical to OperationalCountEQ.
+func OperationalCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v))
+}
+
+// DegradedCount applies equality check predicate on the "degraded_count" field. It's identical to DegradedCountEQ.
+func DegradedCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v))
+}
+
+// FailedCount applies equality check predicate on the "failed_count" field. It's identical to FailedCountEQ.
+func FailedCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v))
+}
+
+// ErrorCount applies equality check predicate on the "error_count" field. It's identical to ErrorCountEQ.
+func ErrorCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v))
+}
+
+// SumLatencyMs applies equality check predicate on the "sum_latency_ms" field. It's identical to SumLatencyMsEQ.
+func SumLatencyMs(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v))
+}
+
+// CountLatency applies equality check predicate on the "count_latency" field. It's identical to CountLatencyEQ.
+func CountLatency(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v))
+}
+
+// SumPingLatencyMs applies equality check predicate on the "sum_ping_latency_ms" field. It's identical to SumPingLatencyMsEQ.
+func SumPingLatencyMs(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v))
+}
+
+// CountPingLatency applies equality check predicate on the "count_ping_latency" field. It's identical to CountPingLatencyEQ.
+func CountPingLatency(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v))
+}
+
+// ComputedAt applies equality check predicate on the "computed_at" field. It's identical to ComputedAtEQ.
+func ComputedAt(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
+}
+
+// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
+func MonitorIDEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field.
+func MonitorIDNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldMonitorID, v))
+}
+
+// MonitorIDIn applies the In predicate on the "monitor_id" field.
+func MonitorIDIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldMonitorID, vs...))
+}
+
+// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field.
+func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldMonitorID, vs...))
+}
+
+// ModelEQ applies the EQ predicate on the "model" field.
+func ModelEQ(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v))
+}
+
+// ModelNEQ applies the NEQ predicate on the "model" field.
+func ModelNEQ(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldModel, v))
+}
+
+// ModelIn applies the In predicate on the "model" field.
+func ModelIn(vs ...string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldModel, vs...))
+}
+
+// ModelNotIn applies the NotIn predicate on the "model" field.
+func ModelNotIn(vs ...string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldModel, vs...))
+}
+
+// ModelGT applies the GT predicate on the "model" field.
+func ModelGT(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldModel, v))
+}
+
+// ModelGTE applies the GTE predicate on the "model" field.
+func ModelGTE(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldModel, v))
+}
+
+// ModelLT applies the LT predicate on the "model" field.
+func ModelLT(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldModel, v))
+}
+
+// ModelLTE applies the LTE predicate on the "model" field.
+func ModelLTE(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldModel, v))
+}
+
+// ModelContains applies the Contains predicate on the "model" field.
+func ModelContains(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldContains(FieldModel, v))
+}
+
+// ModelHasPrefix applies the HasPrefix predicate on the "model" field.
+func ModelHasPrefix(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldHasPrefix(FieldModel, v))
+}
+
+// ModelHasSuffix applies the HasSuffix predicate on the "model" field.
+func ModelHasSuffix(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldHasSuffix(FieldModel, v))
+}
+
+// ModelEqualFold applies the EqualFold predicate on the "model" field.
+func ModelEqualFold(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEqualFold(FieldModel, v))
+}
+
+// ModelContainsFold applies the ContainsFold predicate on the "model" field.
+func ModelContainsFold(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldContainsFold(FieldModel, v))
+}
+
+// BucketDateEQ applies the EQ predicate on the "bucket_date" field.
+func BucketDateEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v))
+}
+
+// BucketDateNEQ applies the NEQ predicate on the "bucket_date" field.
+func BucketDateNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldBucketDate, v))
+}
+
+// BucketDateIn applies the In predicate on the "bucket_date" field.
+func BucketDateIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldBucketDate, vs...))
+}
+
+// BucketDateNotIn applies the NotIn predicate on the "bucket_date" field.
+func BucketDateNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldBucketDate, vs...))
+}
+
+// BucketDateGT applies the GT predicate on the "bucket_date" field.
+func BucketDateGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldBucketDate, v))
+}
+
+// BucketDateGTE applies the GTE predicate on the "bucket_date" field.
+func BucketDateGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldBucketDate, v))
+}
+
+// BucketDateLT applies the LT predicate on the "bucket_date" field.
+func BucketDateLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldBucketDate, v))
+}
+
+// BucketDateLTE applies the LTE predicate on the "bucket_date" field.
+func BucketDateLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldBucketDate, v))
+}
+
+// TotalChecksEQ applies the EQ predicate on the "total_checks" field.
+func TotalChecksEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v))
+}
+
+// TotalChecksNEQ applies the NEQ predicate on the "total_checks" field.
+func TotalChecksNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldTotalChecks, v))
+}
+
+// TotalChecksIn applies the In predicate on the "total_checks" field.
+func TotalChecksIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldTotalChecks, vs...))
+}
+
+// TotalChecksNotIn applies the NotIn predicate on the "total_checks" field.
+func TotalChecksNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldTotalChecks, vs...))
+}
+
+// TotalChecksGT applies the GT predicate on the "total_checks" field.
+func TotalChecksGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldTotalChecks, v))
+}
+
+// TotalChecksGTE applies the GTE predicate on the "total_checks" field.
+func TotalChecksGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldTotalChecks, v))
+}
+
+// TotalChecksLT applies the LT predicate on the "total_checks" field.
+func TotalChecksLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldTotalChecks, v))
+}
+
+// TotalChecksLTE applies the LTE predicate on the "total_checks" field.
+func TotalChecksLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldTotalChecks, v))
+}
+
+// OkCountEQ applies the EQ predicate on the "ok_count" field.
+func OkCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v))
+}
+
+// OkCountNEQ applies the NEQ predicate on the "ok_count" field.
+func OkCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOkCount, v))
+}
+
+// OkCountIn applies the In predicate on the "ok_count" field.
+func OkCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOkCount, vs...))
+}
+
+// OkCountNotIn applies the NotIn predicate on the "ok_count" field.
+func OkCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOkCount, vs...))
+}
+
+// OkCountGT applies the GT predicate on the "ok_count" field.
+func OkCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOkCount, v))
+}
+
+// OkCountGTE applies the GTE predicate on the "ok_count" field.
+func OkCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOkCount, v))
+}
+
+// OkCountLT applies the LT predicate on the "ok_count" field.
+func OkCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOkCount, v))
+}
+
+// OkCountLTE applies the LTE predicate on the "ok_count" field.
+func OkCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOkCount, v))
+}
+
+// OperationalCountEQ applies the EQ predicate on the "operational_count" field.
+func OperationalCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v))
+}
+
+// OperationalCountNEQ applies the NEQ predicate on the "operational_count" field.
+func OperationalCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOperationalCount, v))
+}
+
+// OperationalCountIn applies the In predicate on the "operational_count" field.
+func OperationalCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOperationalCount, vs...))
+}
+
+// OperationalCountNotIn applies the NotIn predicate on the "operational_count" field.
+func OperationalCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOperationalCount, vs...))
+}
+
+// OperationalCountGT applies the GT predicate on the "operational_count" field.
+func OperationalCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOperationalCount, v))
+}
+
+// OperationalCountGTE applies the GTE predicate on the "operational_count" field.
+func OperationalCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOperationalCount, v))
+}
+
+// OperationalCountLT applies the LT predicate on the "operational_count" field.
+func OperationalCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOperationalCount, v))
+}
+
+// OperationalCountLTE applies the LTE predicate on the "operational_count" field.
+func OperationalCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOperationalCount, v))
+}
+
+// DegradedCountEQ applies the EQ predicate on the "degraded_count" field.
+func DegradedCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v))
+}
+
+// DegradedCountNEQ applies the NEQ predicate on the "degraded_count" field.
+func DegradedCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldDegradedCount, v))
+}
+
+// DegradedCountIn applies the In predicate on the "degraded_count" field.
+func DegradedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldDegradedCount, vs...))
+}
+
+// DegradedCountNotIn applies the NotIn predicate on the "degraded_count" field.
+func DegradedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldDegradedCount, vs...))
+}
+
+// DegradedCountGT applies the GT predicate on the "degraded_count" field.
+func DegradedCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldDegradedCount, v))
+}
+
+// DegradedCountGTE applies the GTE predicate on the "degraded_count" field.
+func DegradedCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldDegradedCount, v))
+}
+
+// DegradedCountLT applies the LT predicate on the "degraded_count" field.
+func DegradedCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldDegradedCount, v))
+}
+
+// DegradedCountLTE applies the LTE predicate on the "degraded_count" field.
+func DegradedCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldDegradedCount, v))
+}
+
+// FailedCountEQ applies the EQ predicate on the "failed_count" field.
+func FailedCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v))
+}
+
+// FailedCountNEQ applies the NEQ predicate on the "failed_count" field.
+func FailedCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldFailedCount, v))
+}
+
+// FailedCountIn applies the In predicate on the "failed_count" field.
+func FailedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldFailedCount, vs...))
+}
+
+// FailedCountNotIn applies the NotIn predicate on the "failed_count" field.
+func FailedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldFailedCount, vs...))
+}
+
+// FailedCountGT applies the GT predicate on the "failed_count" field.
+func FailedCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldFailedCount, v))
+}
+
+// FailedCountGTE applies the GTE predicate on the "failed_count" field.
+func FailedCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldFailedCount, v))
+}
+
+// FailedCountLT applies the LT predicate on the "failed_count" field.
+func FailedCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldFailedCount, v))
+}
+
+// FailedCountLTE applies the LTE predicate on the "failed_count" field.
+func FailedCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldFailedCount, v))
+}
+
+// ErrorCountEQ applies the EQ predicate on the "error_count" field.
+func ErrorCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v))
+}
+
+// ErrorCountNEQ applies the NEQ predicate on the "error_count" field.
+func ErrorCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldErrorCount, v))
+}
+
+// ErrorCountIn applies the In predicate on the "error_count" field.
+func ErrorCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldErrorCount, vs...))
+}
+
+// ErrorCountNotIn applies the NotIn predicate on the "error_count" field.
+func ErrorCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldErrorCount, vs...))
+}
+
+// ErrorCountGT applies the GT predicate on the "error_count" field.
+func ErrorCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldErrorCount, v))
+}
+
+// ErrorCountGTE applies the GTE predicate on the "error_count" field.
+func ErrorCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldErrorCount, v))
+}
+
+// ErrorCountLT applies the LT predicate on the "error_count" field.
+func ErrorCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldErrorCount, v))
+}
+
+// ErrorCountLTE applies the LTE predicate on the "error_count" field.
+func ErrorCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldErrorCount, v))
+}
+
+// SumLatencyMsEQ applies the EQ predicate on the "sum_latency_ms" field.
+func SumLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsNEQ applies the NEQ predicate on the "sum_latency_ms" field.
+func SumLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsIn applies the In predicate on the "sum_latency_ms" field.
+func SumLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumLatencyMs, vs...))
+}
+
+// SumLatencyMsNotIn applies the NotIn predicate on the "sum_latency_ms" field.
+func SumLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumLatencyMs, vs...))
+}
+
+// SumLatencyMsGT applies the GT predicate on the "sum_latency_ms" field.
+func SumLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsGTE applies the GTE predicate on the "sum_latency_ms" field.
+func SumLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsLT applies the LT predicate on the "sum_latency_ms" field.
+func SumLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsLTE applies the LTE predicate on the "sum_latency_ms" field.
+func SumLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumLatencyMs, v))
+}
+
+// CountLatencyEQ applies the EQ predicate on the "count_latency" field.
+func CountLatencyEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v))
+}
+
+// CountLatencyNEQ applies the NEQ predicate on the "count_latency" field.
+func CountLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountLatency, v))
+}
+
+// CountLatencyIn applies the In predicate on the "count_latency" field.
+func CountLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountLatency, vs...))
+}
+
+// CountLatencyNotIn applies the NotIn predicate on the "count_latency" field.
+func CountLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountLatency, vs...))
+}
+
+// CountLatencyGT applies the GT predicate on the "count_latency" field.
+func CountLatencyGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountLatency, v))
+}
+
+// CountLatencyGTE applies the GTE predicate on the "count_latency" field.
+func CountLatencyGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountLatency, v))
+}
+
+// CountLatencyLT applies the LT predicate on the "count_latency" field.
+func CountLatencyLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountLatency, v))
+}
+
+// CountLatencyLTE applies the LTE predicate on the "count_latency" field.
+func CountLatencyLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountLatency, v))
+}
+
+// SumPingLatencyMsEQ applies the EQ predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsNEQ applies the NEQ predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsIn applies the In predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumPingLatencyMs, vs...))
+}
+
+// SumPingLatencyMsNotIn applies the NotIn predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumPingLatencyMs, vs...))
+}
+
+// SumPingLatencyMsGT applies the GT predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsGTE applies the GTE predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsLT applies the LT predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsLTE applies the LTE predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumPingLatencyMs, v))
+}
+
+// CountPingLatencyEQ applies the EQ predicate on the "count_ping_latency" field.
+func CountPingLatencyEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyNEQ applies the NEQ predicate on the "count_ping_latency" field.
+func CountPingLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyIn applies the In predicate on the "count_ping_latency" field.
+func CountPingLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountPingLatency, vs...))
+}
+
+// CountPingLatencyNotIn applies the NotIn predicate on the "count_ping_latency" field.
+func CountPingLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountPingLatency, vs...))
+}
+
+// CountPingLatencyGT applies the GT predicate on the "count_ping_latency" field.
+func CountPingLatencyGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyGTE applies the GTE predicate on the "count_ping_latency" field.
+func CountPingLatencyGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyLT applies the LT predicate on the "count_ping_latency" field.
+func CountPingLatencyLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyLTE applies the LTE predicate on the "count_ping_latency" field.
+func CountPingLatencyLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountPingLatency, v))
+}
+
+// ComputedAtEQ applies the EQ predicate on the "computed_at" field.
+func ComputedAtEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
+}
+
+// ComputedAtNEQ applies the NEQ predicate on the "computed_at" field.
+func ComputedAtNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldComputedAt, v))
+}
+
+// ComputedAtIn applies the In predicate on the "computed_at" field.
+func ComputedAtIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldComputedAt, vs...))
+}
+
+// ComputedAtNotIn applies the NotIn predicate on the "computed_at" field.
+func ComputedAtNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldComputedAt, vs...))
+}
+
+// ComputedAtGT applies the GT predicate on the "computed_at" field.
+func ComputedAtGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldComputedAt, v))
+}
+
+// ComputedAtGTE applies the GTE predicate on the "computed_at" field.
+func ComputedAtGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldComputedAt, v))
+}
+
+// ComputedAtLT applies the LT predicate on the "computed_at" field.
+func ComputedAtLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldComputedAt, v))
+}
+
+// ComputedAtLTE applies the LTE predicate on the "computed_at" field.
+func ComputedAtLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldComputedAt, v))
+}
+
+// HasMonitor applies the HasEdge predicate on the "monitor" edge.
+func HasMonitor() predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates).
+func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ step := newMonitorStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitordailyrollup_create.go b/backend/ent/channelmonitordailyrollup_create.go
new file mode 100644
index 00000000..5f8754ba
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_create.go
@@ -0,0 +1,1509 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+)
+
+// ChannelMonitorDailyRollupCreate is the builder for creating a ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupCreate struct {
+ config
+ mutation *ChannelMonitorDailyRollupMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitorID(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetMonitorID(v)
+ return _c
+}
+
+// SetModel sets the "model" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetModel(v string) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetModel(v)
+ return _c
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetBucketDate(v)
+ return _c
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetTotalChecks(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetTotalChecks(v)
+ return _c
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetTotalChecks(*v)
+ }
+ return _c
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetOkCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetOkCount(v)
+ return _c
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetOkCount(*v)
+ }
+ return _c
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetOperationalCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetOperationalCount(v)
+ return _c
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetOperationalCount(*v)
+ }
+ return _c
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetDegradedCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetDegradedCount(v)
+ return _c
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetDegradedCount(*v)
+ }
+ return _c
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetFailedCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetFailedCount(v)
+ return _c
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetFailedCount(*v)
+ }
+ return _c
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetErrorCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetErrorCount(v)
+ return _c
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetErrorCount(*v)
+ }
+ return _c
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetSumLatencyMs(v)
+ return _c
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetSumLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetCountLatency(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetCountLatency(v)
+ return _c
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetCountLatency(*v)
+ }
+ return _c
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetSumPingLatencyMs(v)
+ return _c
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetSumPingLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetCountPingLatency(v)
+ return _c
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetCountPingLatency(*v)
+ }
+ return _c
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetComputedAt(v)
+ return _c
+}
+
+// SetNillableComputedAt sets the "computed_at" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableComputedAt(v *time.Time) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetComputedAt(*v)
+ }
+ return _c
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupCreate {
+ return _c.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_c *ChannelMonitorDailyRollupCreate) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorDailyRollup in the database.
+func (_c *ChannelMonitorDailyRollupCreate) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorDailyRollupCreate) SaveX(ctx context.Context) *ChannelMonitorDailyRollup {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorDailyRollupCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorDailyRollupCreate) defaults() {
+ if _, ok := _c.mutation.TotalChecks(); !ok {
+ v := channelmonitordailyrollup.DefaultTotalChecks
+ _c.mutation.SetTotalChecks(v)
+ }
+ if _, ok := _c.mutation.OkCount(); !ok {
+ v := channelmonitordailyrollup.DefaultOkCount
+ _c.mutation.SetOkCount(v)
+ }
+ if _, ok := _c.mutation.OperationalCount(); !ok {
+ v := channelmonitordailyrollup.DefaultOperationalCount
+ _c.mutation.SetOperationalCount(v)
+ }
+ if _, ok := _c.mutation.DegradedCount(); !ok {
+ v := channelmonitordailyrollup.DefaultDegradedCount
+ _c.mutation.SetDegradedCount(v)
+ }
+ if _, ok := _c.mutation.FailedCount(); !ok {
+ v := channelmonitordailyrollup.DefaultFailedCount
+ _c.mutation.SetFailedCount(v)
+ }
+ if _, ok := _c.mutation.ErrorCount(); !ok {
+ v := channelmonitordailyrollup.DefaultErrorCount
+ _c.mutation.SetErrorCount(v)
+ }
+ if _, ok := _c.mutation.SumLatencyMs(); !ok {
+ v := channelmonitordailyrollup.DefaultSumLatencyMs
+ _c.mutation.SetSumLatencyMs(v)
+ }
+ if _, ok := _c.mutation.CountLatency(); !ok {
+ v := channelmonitordailyrollup.DefaultCountLatency
+ _c.mutation.SetCountLatency(v)
+ }
+ if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
+ v := channelmonitordailyrollup.DefaultSumPingLatencyMs
+ _c.mutation.SetSumPingLatencyMs(v)
+ }
+ if _, ok := _c.mutation.CountPingLatency(); !ok {
+ v := channelmonitordailyrollup.DefaultCountPingLatency
+ _c.mutation.SetCountPingLatency(v)
+ }
+ if _, ok := _c.mutation.ComputedAt(); !ok {
+ v := channelmonitordailyrollup.DefaultComputedAt()
+ _c.mutation.SetComputedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorDailyRollupCreate) check() error {
+ if _, ok := _c.mutation.MonitorID(); !ok {
+ return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.monitor_id"`)}
+ }
+ if _, ok := _c.mutation.Model(); !ok {
+ return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.model"`)}
+ }
+ if v, ok := _c.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.BucketDate(); !ok {
+ return &ValidationError{Name: "bucket_date", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.bucket_date"`)}
+ }
+ if _, ok := _c.mutation.TotalChecks(); !ok {
+ return &ValidationError{Name: "total_checks", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.total_checks"`)}
+ }
+ if _, ok := _c.mutation.OkCount(); !ok {
+ return &ValidationError{Name: "ok_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.ok_count"`)}
+ }
+ if _, ok := _c.mutation.OperationalCount(); !ok {
+ return &ValidationError{Name: "operational_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.operational_count"`)}
+ }
+ if _, ok := _c.mutation.DegradedCount(); !ok {
+ return &ValidationError{Name: "degraded_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.degraded_count"`)}
+ }
+ if _, ok := _c.mutation.FailedCount(); !ok {
+ return &ValidationError{Name: "failed_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.failed_count"`)}
+ }
+ if _, ok := _c.mutation.ErrorCount(); !ok {
+ return &ValidationError{Name: "error_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.error_count"`)}
+ }
+ if _, ok := _c.mutation.SumLatencyMs(); !ok {
+ return &ValidationError{Name: "sum_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_latency_ms"`)}
+ }
+ if _, ok := _c.mutation.CountLatency(); !ok {
+ return &ValidationError{Name: "count_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_latency"`)}
+ }
+ if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
+ return &ValidationError{Name: "sum_ping_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_ping_latency_ms"`)}
+ }
+ if _, ok := _c.mutation.CountPingLatency(); !ok {
+ return &ValidationError{Name: "count_ping_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_ping_latency"`)}
+ }
+ if _, ok := _c.mutation.ComputedAt(); !ok {
+ return &ValidationError{Name: "computed_at", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.computed_at"`)}
+ }
+ if len(_c.mutation.MonitorIDs()) == 0 {
+ return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorDailyRollup.monitor"`)}
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorDailyRollupCreate) sqlSave(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRollup, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorDailyRollup{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ _node.Model = value
+ }
+ if value, ok := _c.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ _node.BucketDate = value
+ }
+ if value, ok := _c.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ _node.TotalChecks = value
+ }
+ if value, ok := _c.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ _node.OkCount = value
+ }
+ if value, ok := _c.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ _node.OperationalCount = value
+ }
+ if value, ok := _c.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ _node.DegradedCount = value
+ }
+ if value, ok := _c.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ _node.FailedCount = value
+ }
+ if value, ok := _c.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ _node.ErrorCount = value
+ }
+ if value, ok := _c.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ _node.SumLatencyMs = value
+ }
+ if value, ok := _c.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ _node.CountLatency = value
+ }
+ if value, ok := _c.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ _node.SumPingLatencyMs = value
+ }
+ if value, ok := _c.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ _node.CountPingLatency = value
+ }
+ if value, ok := _c.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ _node.ComputedAt = value
+ }
+ if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.MonitorID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// SetMonitorID(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorDailyRollupUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreate) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorDailyRollupUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorDailyRollupUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorDailyRollup node.
+ ChannelMonitorDailyRollupUpsertOne struct {
+ create *ChannelMonitorDailyRollupCreate
+ }
+
+ // ChannelMonitorDailyRollupUpsert is the "OnConflict" setter.
+ ChannelMonitorDailyRollupUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldMonitorID, v)
+ return u
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateMonitorID() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldMonitorID)
+ return u
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetModel(v string) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldModel, v)
+ return u
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateModel() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldModel)
+ return u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldBucketDate, v)
+ return u
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateBucketDate() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldBucketDate)
+ return u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldTotalChecks, v)
+ return u
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldTotalChecks)
+ return u
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldTotalChecks, v)
+ return u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetOkCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldOkCount, v)
+ return u
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateOkCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldOkCount)
+ return u
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddOkCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldOkCount, v)
+ return u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldOperationalCount, v)
+ return u
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldOperationalCount)
+ return u
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldOperationalCount, v)
+ return u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldDegradedCount, v)
+ return u
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldDegradedCount)
+ return u
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldDegradedCount, v)
+ return u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldFailedCount, v)
+ return u
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateFailedCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldFailedCount)
+ return u
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldFailedCount, v)
+ return u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldErrorCount, v)
+ return u
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateErrorCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldErrorCount)
+ return u
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldErrorCount, v)
+ return u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldSumLatencyMs, v)
+ return u
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldSumLatencyMs)
+ return u
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldSumLatencyMs, v)
+ return u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldCountLatency, v)
+ return u
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateCountLatency() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldCountLatency)
+ return u
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldCountLatency, v)
+ return u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
+ return u
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldSumPingLatencyMs)
+ return u
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
+ return u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldCountPingLatency, v)
+ return u
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldCountPingLatency)
+ return u
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldCountPingLatency, v)
+ return u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldComputedAt, v)
+ return u
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateComputedAt() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldComputedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateNewValues() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertOne) Ignore() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorDailyRollupUpsertOne) DoNothing() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorDailyRollupUpsertOne) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetModel(v string) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateModel() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetBucketDate(v)
+ })
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateBucketDate()
+ })
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetTotalChecks(v)
+ })
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddTotalChecks(v)
+ })
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateTotalChecks()
+ })
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOkCount(v)
+ })
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOkCount(v)
+ })
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOkCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOkCount()
+ })
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOperationalCount(v)
+ })
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOperationalCount(v)
+ })
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOperationalCount()
+ })
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDegradedCount(v)
+ })
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddDegradedCount(v)
+ })
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDegradedCount()
+ })
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetFailedCount(v)
+ })
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddFailedCount(v)
+ })
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateFailedCount()
+ })
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetErrorCount(v)
+ })
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddErrorCount(v)
+ })
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateErrorCount()
+ })
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumLatencyMs(v)
+ })
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumLatencyMs(v)
+ })
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumLatencyMs()
+ })
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountLatency(v)
+ })
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountLatency(v)
+ })
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountLatency()
+ })
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumPingLatencyMs(v)
+ })
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumPingLatencyMs(v)
+ })
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumPingLatencyMs()
+ })
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountPingLatency(v)
+ })
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountPingLatency(v)
+ })
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountPingLatency()
+ })
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetComputedAt(v)
+ })
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateComputedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorDailyRollupUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorDailyRollupCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorDailyRollupUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorDailyRollupCreateBulk is the builder for creating many ChannelMonitorDailyRollup entities in bulk.
+type ChannelMonitorDailyRollupCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorDailyRollupCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorDailyRollup entities in the database.
+func (_c *ChannelMonitorDailyRollupCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorDailyRollup, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorDailyRollupMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorDailyRollup {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorDailyRollupCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorDailyRollup.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorDailyRollupUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorDailyRollupUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorDailyRollupUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorDailyRollup nodes.
+type ChannelMonitorDailyRollupUpsertBulk struct {
+ create *ChannelMonitorDailyRollupCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateNewValues() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertBulk) Ignore() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorDailyRollupUpsertBulk) DoNothing() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorDailyRollupUpsertBulk) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetModel(v string) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateModel() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetBucketDate(v)
+ })
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateBucketDate()
+ })
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetTotalChecks(v)
+ })
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddTotalChecks(v)
+ })
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateTotalChecks()
+ })
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOkCount(v)
+ })
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOkCount(v)
+ })
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOkCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOkCount()
+ })
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOperationalCount(v)
+ })
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOperationalCount(v)
+ })
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOperationalCount()
+ })
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDegradedCount(v)
+ })
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddDegradedCount(v)
+ })
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDegradedCount()
+ })
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetFailedCount(v)
+ })
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddFailedCount(v)
+ })
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateFailedCount()
+ })
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetErrorCount(v)
+ })
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddErrorCount(v)
+ })
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateErrorCount()
+ })
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumLatencyMs(v)
+ })
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumLatencyMs(v)
+ })
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumLatencyMs()
+ })
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountLatency(v)
+ })
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountLatency(v)
+ })
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountLatency()
+ })
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumPingLatencyMs(v)
+ })
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumPingLatencyMs(v)
+ })
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumPingLatencyMs()
+ })
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountPingLatency(v)
+ })
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountPingLatency(v)
+ })
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountPingLatency()
+ })
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetComputedAt(v)
+ })
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateComputedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorDailyRollupUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorDailyRollupCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorDailyRollupCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitordailyrollup_delete.go b/backend/ent/channelmonitordailyrollup_delete.go
new file mode 100644
index 00000000..460c94f8
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupDelete is the builder for deleting a ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder.
+func (_d *ChannelMonitorDailyRollupDelete) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorDailyRollupDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDailyRollupDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorDailyRollupDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorDailyRollupDeleteOne is the builder for deleting a single ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupDeleteOne struct {
+ _d *ChannelMonitorDailyRollupDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder.
+func (_d *ChannelMonitorDailyRollupDeleteOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorDailyRollupDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDailyRollupDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitordailyrollup_query.go b/backend/ent/channelmonitordailyrollup_query.go
new file mode 100644
index 00000000..e34afc61
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupQuery is the builder for querying ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitordailyrollup.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorDailyRollup
+ withMonitor *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorDailyRollupQuery builder.
+func (_q *ChannelMonitorDailyRollupQuery) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorDailyRollupQuery) Limit(limit int) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorDailyRollupQuery) Offset(offset int) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorDailyRollupQuery) Unique(unique bool) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorDailyRollupQuery) Order(o ...channelmonitordailyrollup.OrderOption) *ChannelMonitorDailyRollupQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitor chains the current query on the "monitor" edge.
+func (_q *ChannelMonitorDailyRollupQuery) QueryMonitor() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorDailyRollup entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup was found.
+func (_q *ChannelMonitorDailyRollupQuery) First(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitordailyrollup.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) FirstX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorDailyRollup ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup ID was found.
+func (_q *ChannelMonitorDailyRollupQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorDailyRollup entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup entity is found.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup entities are found.
+func (_q *ChannelMonitorDailyRollupQuery) Only(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ return nil, &NotSingularError{channelmonitordailyrollup.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorDailyRollup ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ err = &NotSingularError{channelmonitordailyrollup.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorDailyRollups.
+func (_q *ChannelMonitorDailyRollupQuery) All(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorDailyRollup, *ChannelMonitorDailyRollupQuery]()
+ return withInterceptors[[]*ChannelMonitorDailyRollup](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) AllX(ctx context.Context) []*ChannelMonitorDailyRollup {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorDailyRollup IDs.
+func (_q *ChannelMonitorDailyRollupQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitordailyrollup.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorDailyRollupQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorDailyRollupQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorDailyRollupQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorDailyRollupQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorDailyRollupQuery) Clone() *ChannelMonitorDailyRollupQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorDailyRollupQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitordailyrollup.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorDailyRollup{}, _q.predicates...),
+ withMonitor: _q.withMonitor.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitor tells the query-builder to eager-load the nodes that are connected to
+// the "monitor" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorDailyRollupQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitor = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorDailyRollup.Query().
+// GroupBy(channelmonitordailyrollup.FieldMonitorID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorDailyRollupQuery) GroupBy(field string, fields ...string) *ChannelMonitorDailyRollupGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorDailyRollupGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitordailyrollup.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// }
+//
+// client.ChannelMonitorDailyRollup.Query().
+// Select(channelmonitordailyrollup.FieldMonitorID).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorDailyRollupQuery) Select(fields ...string) *ChannelMonitorDailyRollupSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorDailyRollupSelect{ChannelMonitorDailyRollupQuery: _q}
+ sbuild.label = channelmonitordailyrollup.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorDailyRollupSelect configured with the given aggregations.
+func (_q *ChannelMonitorDailyRollupQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitordailyrollup.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorDailyRollup, error) {
+ var (
+ nodes = []*ChannelMonitorDailyRollup{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitor != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorDailyRollup).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorDailyRollup{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitor; query != nil {
+ if err := _q.loadMonitor(ctx, query, nodes, nil,
+ func(n *ChannelMonitorDailyRollup, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorDailyRollup, init func(*ChannelMonitorDailyRollup), assign func(*ChannelMonitorDailyRollup, *ChannelMonitor)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitorDailyRollup)
+ for i := range nodes {
+ fk := nodes[i].MonitorID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitor.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitordailyrollup.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withMonitor != nil {
+ _spec.Node.AddColumnOnce(channelmonitordailyrollup.FieldMonitorID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitordailyrollup.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitordailyrollup.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorDailyRollupQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorDailyRollupQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorDailyRollupGroupBy is the group-by builder for ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupGroupBy struct {
+ selector
+ build *ChannelMonitorDailyRollupQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorDailyRollupGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorDailyRollupGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorDailyRollupGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorDailyRollupSelect is the builder for selecting fields of ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupSelect struct {
+ *ChannelMonitorDailyRollupQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorDailyRollupSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorDailyRollupSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupSelect](ctx, _s.ChannelMonitorDailyRollupQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorDailyRollupSelect) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitordailyrollup_update.go b/backend/ent/channelmonitordailyrollup_update.go
new file mode 100644
index 00000000..02cd86c5
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_update.go
@@ -0,0 +1,961 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupUpdate is the builder for updating ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder.
+func (_u *ChannelMonitorDailyRollupUpdate) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetModel(v string) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetBucketDate(v)
+ return _u
+}
+
+// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetBucketDate(*v)
+ }
+ return _u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetTotalChecks()
+ _u.mutation.SetTotalChecks(v)
+ return _u
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetTotalChecks(*v)
+ }
+ return _u
+}
+
+// AddTotalChecks adds value to the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddTotalChecks(v)
+ return _u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetOkCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetOkCount()
+ _u.mutation.SetOkCount(v)
+ return _u
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetOkCount(*v)
+ }
+ return _u
+}
+
+// AddOkCount adds value to the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddOkCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddOkCount(v)
+ return _u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetOperationalCount()
+ _u.mutation.SetOperationalCount(v)
+ return _u
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetOperationalCount(*v)
+ }
+ return _u
+}
+
+// AddOperationalCount adds value to the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddOperationalCount(v)
+ return _u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetDegradedCount()
+ _u.mutation.SetDegradedCount(v)
+ return _u
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetDegradedCount(*v)
+ }
+ return _u
+}
+
+// AddDegradedCount adds value to the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddDegradedCount(v)
+ return _u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetFailedCount()
+ _u.mutation.SetFailedCount(v)
+ return _u
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetFailedCount(*v)
+ }
+ return _u
+}
+
+// AddFailedCount adds value to the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddFailedCount(v)
+ return _u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetErrorCount()
+ _u.mutation.SetErrorCount(v)
+ return _u
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetErrorCount(*v)
+ }
+ return _u
+}
+
+// AddErrorCount adds value to the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddErrorCount(v)
+ return _u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetSumLatencyMs()
+ _u.mutation.SetSumLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetSumLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumLatencyMs adds value to the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddSumLatencyMs(v)
+ return _u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetCountLatency()
+ _u.mutation.SetCountLatency(v)
+ return _u
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetCountLatency(*v)
+ }
+ return _u
+}
+
+// AddCountLatency adds value to the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddCountLatency(v)
+ return _u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetSumPingLatencyMs()
+ _u.mutation.SetSumPingLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetSumPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddSumPingLatencyMs(v)
+ return _u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetCountPingLatency()
+ _u.mutation.SetCountPingLatency(v)
+ return _u
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetCountPingLatency(*v)
+ }
+ return _u
+}
+
+// AddCountPingLatency adds value to the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddCountPingLatency(v)
+ return _u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetComputedAt(v)
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdate {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_u *ChannelMonitorDailyRollupUpdate) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdate) ClearMonitor() *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorDailyRollupUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorDailyRollupUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorDailyRollupUpdate) defaults() {
+ if _, ok := _u.mutation.ComputedAt(); !ok {
+ v := channelmonitordailyrollup.UpdateDefaultComputedAt()
+ _u.mutation.SetComputedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorDailyRollupUpdate) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorDailyRollupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedTotalChecks(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOkCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOperationalCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedDegradedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedFailedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedErrorCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountPingLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorDailyRollupUpdateOne is the builder for updating a single ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetModel(v string) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetBucketDate(v)
+ return _u
+}
+
+// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetBucketDate(*v)
+ }
+ return _u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetTotalChecks()
+ _u.mutation.SetTotalChecks(v)
+ return _u
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetTotalChecks(*v)
+ }
+ return _u
+}
+
+// AddTotalChecks adds value to the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddTotalChecks(v)
+ return _u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetOkCount()
+ _u.mutation.SetOkCount(v)
+ return _u
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetOkCount(*v)
+ }
+ return _u
+}
+
+// AddOkCount adds value to the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddOkCount(v)
+ return _u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetOperationalCount()
+ _u.mutation.SetOperationalCount(v)
+ return _u
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetOperationalCount(*v)
+ }
+ return _u
+}
+
+// AddOperationalCount adds value to the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddOperationalCount(v)
+ return _u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetDegradedCount()
+ _u.mutation.SetDegradedCount(v)
+ return _u
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetDegradedCount(*v)
+ }
+ return _u
+}
+
+// AddDegradedCount adds value to the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddDegradedCount(v)
+ return _u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetFailedCount()
+ _u.mutation.SetFailedCount(v)
+ return _u
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetFailedCount(*v)
+ }
+ return _u
+}
+
+// AddFailedCount adds value to the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddFailedCount(v)
+ return _u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetErrorCount()
+ _u.mutation.SetErrorCount(v)
+ return _u
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetErrorCount(*v)
+ }
+ return _u
+}
+
+// AddErrorCount adds value to the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddErrorCount(v)
+ return _u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetSumLatencyMs()
+ _u.mutation.SetSumLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetSumLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumLatencyMs adds value to the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddSumLatencyMs(v)
+ return _u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetCountLatency()
+ _u.mutation.SetCountLatency(v)
+ return _u
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetCountLatency(*v)
+ }
+ return _u
+}
+
+// AddCountLatency adds value to the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddCountLatency(v)
+ return _u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetSumPingLatencyMs()
+ _u.mutation.SetSumPingLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetSumPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddSumPingLatencyMs(v)
+ return _u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetCountPingLatency()
+ _u.mutation.SetCountPingLatency(v)
+ return _u
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetCountPingLatency(*v)
+ }
+ return _u
+}
+
+// AddCountPingLatency adds value to the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddCountPingLatency(v)
+ return _u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetComputedAt(v)
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdateOne {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ClearMonitor() *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Select(field string, fields ...string) *ChannelMonitorDailyRollupUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SaveX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorDailyRollupUpdateOne) defaults() {
+ if _, ok := _u.mutation.ComputedAt(); !ok {
+ v := channelmonitordailyrollup.UpdateDefaultComputedAt()
+ _u.mutation.SetComputedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorDailyRollupUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorDailyRollup, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorDailyRollup.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID)
+ for _, f := range fields {
+ if !channelmonitordailyrollup.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitordailyrollup.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedTotalChecks(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOkCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOperationalCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedDegradedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedFailedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedErrorCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountPingLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorDailyRollup{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitorhistory.go b/backend/ent/channelmonitorhistory.go
new file mode 100644
index 00000000..70dde542
--- /dev/null
+++ b/backend/ent/channelmonitorhistory.go
@@ -0,0 +1,207 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+)
+
+// ChannelMonitorHistory is the model entity for the ChannelMonitorHistory schema.
+type ChannelMonitorHistory struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // MonitorID holds the value of the "monitor_id" field.
+ MonitorID int64 `json:"monitor_id,omitempty"`
+ // Model holds the value of the "model" field.
+ Model string `json:"model,omitempty"`
+ // Status holds the value of the "status" field.
+ Status channelmonitorhistory.Status `json:"status,omitempty"`
+ // LatencyMs holds the value of the "latency_ms" field.
+ LatencyMs *int `json:"latency_ms,omitempty"`
+ // PingLatencyMs holds the value of the "ping_latency_ms" field.
+ PingLatencyMs *int `json:"ping_latency_ms,omitempty"`
+ // Message holds the value of the "message" field.
+ Message string `json:"message,omitempty"`
+ // CheckedAt holds the value of the "checked_at" field.
+ CheckedAt time.Time `json:"checked_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorHistoryQuery when eager-loading is set.
+ Edges ChannelMonitorHistoryEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorHistoryEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorHistoryEdges struct {
+ // Monitor holds the value of the monitor edge.
+ Monitor *ChannelMonitor `json:"monitor,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorOrErr returns the Monitor value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorHistoryEdges) MonitorOrErr() (*ChannelMonitor, error) {
+ if e.Monitor != nil {
+ return e.Monitor, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: channelmonitor.Label}
+ }
+ return nil, &NotLoadedError{edge: "monitor"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorHistory) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorhistory.FieldID, channelmonitorhistory.FieldMonitorID, channelmonitorhistory.FieldLatencyMs, channelmonitorhistory.FieldPingLatencyMs:
+ values[i] = new(sql.NullInt64)
+ case channelmonitorhistory.FieldModel, channelmonitorhistory.FieldStatus, channelmonitorhistory.FieldMessage:
+ values[i] = new(sql.NullString)
+ case channelmonitorhistory.FieldCheckedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorHistory fields.
+func (_m *ChannelMonitorHistory) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorhistory.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitorhistory.FieldMonitorID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
+ } else if value.Valid {
+ _m.MonitorID = value.Int64
+ }
+ case channelmonitorhistory.FieldModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field model", values[i])
+ } else if value.Valid {
+ _m.Model = value.String
+ }
+ case channelmonitorhistory.FieldStatus:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field status", values[i])
+ } else if value.Valid {
+ _m.Status = channelmonitorhistory.Status(value.String)
+ }
+ case channelmonitorhistory.FieldLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field latency_ms", values[i])
+ } else if value.Valid {
+ _m.LatencyMs = new(int)
+ *_m.LatencyMs = int(value.Int64)
+ }
+ case channelmonitorhistory.FieldPingLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field ping_latency_ms", values[i])
+ } else if value.Valid {
+ _m.PingLatencyMs = new(int)
+ *_m.PingLatencyMs = int(value.Int64)
+ }
+ case channelmonitorhistory.FieldMessage:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field message", values[i])
+ } else if value.Valid {
+ _m.Message = value.String
+ }
+ case channelmonitorhistory.FieldCheckedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field checked_at", values[i])
+ } else if value.Valid {
+ _m.CheckedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorHistory.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorHistory) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitor queries the "monitor" edge of the ChannelMonitorHistory entity.
+func (_m *ChannelMonitorHistory) QueryMonitor() *ChannelMonitorQuery {
+ return NewChannelMonitorHistoryClient(_m.config).QueryMonitor(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorHistory.
+// Note that you need to call ChannelMonitorHistory.Unwrap() before calling this method if this ChannelMonitorHistory
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorHistory) Update() *ChannelMonitorHistoryUpdateOne {
+ return NewChannelMonitorHistoryClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorHistory entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorHistory) Unwrap() *ChannelMonitorHistory {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorHistory is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorHistory) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorHistory(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("monitor_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
+ builder.WriteString(", ")
+ builder.WriteString("model=")
+ builder.WriteString(_m.Model)
+ builder.WriteString(", ")
+ builder.WriteString("status=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Status))
+ builder.WriteString(", ")
+ if v := _m.LatencyMs; v != nil {
+ builder.WriteString("latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.PingLatencyMs; v != nil {
+ builder.WriteString("ping_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("message=")
+ builder.WriteString(_m.Message)
+ builder.WriteString(", ")
+ builder.WriteString("checked_at=")
+ builder.WriteString(_m.CheckedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorHistories is a parsable slice of ChannelMonitorHistory.
+type ChannelMonitorHistories []*ChannelMonitorHistory
diff --git a/backend/ent/channelmonitorhistory/channelmonitorhistory.go b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
new file mode 100644
index 00000000..6a9dc006
--- /dev/null
+++ b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
@@ -0,0 +1,158 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorhistory
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitorhistory type in the database.
+ Label = "channel_monitor_history"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldMonitorID holds the string denoting the monitor_id field in the database.
+ FieldMonitorID = "monitor_id"
+ // FieldModel holds the string denoting the model field in the database.
+ FieldModel = "model"
+ // FieldStatus holds the string denoting the status field in the database.
+ FieldStatus = "status"
+ // FieldLatencyMs holds the string denoting the latency_ms field in the database.
+ FieldLatencyMs = "latency_ms"
+ // FieldPingLatencyMs holds the string denoting the ping_latency_ms field in the database.
+ FieldPingLatencyMs = "ping_latency_ms"
+ // FieldMessage holds the string denoting the message field in the database.
+ FieldMessage = "message"
+ // FieldCheckedAt holds the string denoting the checked_at field in the database.
+ FieldCheckedAt = "checked_at"
+ // EdgeMonitor holds the string denoting the monitor edge name in mutations.
+ EdgeMonitor = "monitor"
+ // Table holds the table name of the channelmonitorhistory in the database.
+ Table = "channel_monitor_histories"
+ // MonitorTable is the table that holds the monitor relation/edge.
+ MonitorTable = "channel_monitor_histories"
+ // MonitorInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorInverseTable = "channel_monitors"
+ // MonitorColumn is the table column denoting the monitor relation/edge.
+ MonitorColumn = "monitor_id"
+)
+
+// Columns holds all SQL columns for channelmonitorhistory fields.
+var Columns = []string{
+ FieldID,
+ FieldMonitorID,
+ FieldModel,
+ FieldStatus,
+ FieldLatencyMs,
+ FieldPingLatencyMs,
+ FieldMessage,
+ FieldCheckedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ ModelValidator func(string) error
+ // DefaultMessage holds the default value on creation for the "message" field.
+ DefaultMessage string
+ // MessageValidator is a validator for the "message" field. It is called by the builders before save.
+ MessageValidator func(string) error
+ // DefaultCheckedAt holds the default value on creation for the "checked_at" field.
+ DefaultCheckedAt func() time.Time
+)
+
+// Status defines the type for the "status" enum field.
+type Status string
+
+// Status values.
+const (
+ StatusOperational Status = "operational"
+ StatusDegraded Status = "degraded"
+ StatusFailed Status = "failed"
+ StatusError Status = "error"
+)
+
+func (s Status) String() string {
+ return string(s)
+}
+
+// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save.
+func StatusValidator(s Status) error {
+ switch s {
+ case StatusOperational, StatusDegraded, StatusFailed, StatusError:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitorhistory: invalid enum value for status field: %q", s)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitorHistory queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByMonitorID orders the results by the monitor_id field.
+func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
+}
+
+// ByModel orders the results by the model field.
+func ByModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModel, opts...).ToFunc()
+}
+
+// ByStatus orders the results by the status field.
+func ByStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStatus, opts...).ToFunc()
+}
+
+// ByLatencyMs orders the results by the latency_ms field.
+func ByLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLatencyMs, opts...).ToFunc()
+}
+
+// ByPingLatencyMs orders the results by the ping_latency_ms field.
+func ByPingLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPingLatencyMs, opts...).ToFunc()
+}
+
+// ByMessage orders the results by the message field.
+func ByMessage(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMessage, opts...).ToFunc()
+}
+
+// ByCheckedAt orders the results by the checked_at field.
+func ByCheckedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCheckedAt, opts...).ToFunc()
+}
+
+// ByMonitorField orders the results by monitor field.
+func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newMonitorStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+}
diff --git a/backend/ent/channelmonitorhistory/where.go b/backend/ent/channelmonitorhistory/where.go
new file mode 100644
index 00000000..afa73f35
--- /dev/null
+++ b/backend/ent/channelmonitorhistory/where.go
@@ -0,0 +1,444 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorhistory
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldID, id))
+}
+
+// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
+func MonitorID(v int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// Model applies equality check predicate on the "model" field. It's identical to ModelEQ.
+func Model(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldModel, v))
+}
+
+// LatencyMs applies equality check predicate on the "latency_ms" field. It's identical to LatencyMsEQ.
+func LatencyMs(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldLatencyMs, v))
+}
+
+// PingLatencyMs applies equality check predicate on the "ping_latency_ms" field. It's identical to PingLatencyMsEQ.
+func PingLatencyMs(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldPingLatencyMs, v))
+}
+
+// Message applies equality check predicate on the "message" field. It's identical to MessageEQ.
+func Message(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMessage, v))
+}
+
+// CheckedAt applies equality check predicate on the "checked_at" field. It's identical to CheckedAtEQ.
+func CheckedAt(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v))
+}
+
+// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
+func MonitorIDEQ(v int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field.
+func MonitorIDNEQ(v int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldMonitorID, v))
+}
+
+// MonitorIDIn applies the In predicate on the "monitor_id" field.
+func MonitorIDIn(vs ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldMonitorID, vs...))
+}
+
+// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field.
+func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldMonitorID, vs...))
+}
+
+// ModelEQ applies the EQ predicate on the "model" field.
+func ModelEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldModel, v))
+}
+
+// ModelNEQ applies the NEQ predicate on the "model" field.
+func ModelNEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldModel, v))
+}
+
+// ModelIn applies the In predicate on the "model" field.
+func ModelIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldModel, vs...))
+}
+
+// ModelNotIn applies the NotIn predicate on the "model" field.
+func ModelNotIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldModel, vs...))
+}
+
+// ModelGT applies the GT predicate on the "model" field.
+func ModelGT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldModel, v))
+}
+
+// ModelGTE applies the GTE predicate on the "model" field.
+func ModelGTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldModel, v))
+}
+
+// ModelLT applies the LT predicate on the "model" field.
+func ModelLT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldModel, v))
+}
+
+// ModelLTE applies the LTE predicate on the "model" field.
+func ModelLTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldModel, v))
+}
+
+// ModelContains applies the Contains predicate on the "model" field.
+func ModelContains(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContains(FieldModel, v))
+}
+
+// ModelHasPrefix applies the HasPrefix predicate on the "model" field.
+func ModelHasPrefix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasPrefix(FieldModel, v))
+}
+
+// ModelHasSuffix applies the HasSuffix predicate on the "model" field.
+func ModelHasSuffix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasSuffix(FieldModel, v))
+}
+
+// ModelEqualFold applies the EqualFold predicate on the "model" field.
+func ModelEqualFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEqualFold(FieldModel, v))
+}
+
+// ModelContainsFold applies the ContainsFold predicate on the "model" field.
+func ModelContainsFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContainsFold(FieldModel, v))
+}
+
+// StatusEQ applies the EQ predicate on the "status" field.
+func StatusEQ(v Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldStatus, v))
+}
+
+// StatusNEQ applies the NEQ predicate on the "status" field.
+func StatusNEQ(v Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldStatus, v))
+}
+
+// StatusIn applies the In predicate on the "status" field.
+func StatusIn(vs ...Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldStatus, vs...))
+}
+
+// StatusNotIn applies the NotIn predicate on the "status" field.
+func StatusNotIn(vs ...Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldStatus, vs...))
+}
+
+// LatencyMsEQ applies the EQ predicate on the "latency_ms" field.
+func LatencyMsEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldLatencyMs, v))
+}
+
+// LatencyMsNEQ applies the NEQ predicate on the "latency_ms" field.
+func LatencyMsNEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldLatencyMs, v))
+}
+
+// LatencyMsIn applies the In predicate on the "latency_ms" field.
+func LatencyMsIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldLatencyMs, vs...))
+}
+
+// LatencyMsNotIn applies the NotIn predicate on the "latency_ms" field.
+func LatencyMsNotIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldLatencyMs, vs...))
+}
+
+// LatencyMsGT applies the GT predicate on the "latency_ms" field.
+func LatencyMsGT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldLatencyMs, v))
+}
+
+// LatencyMsGTE applies the GTE predicate on the "latency_ms" field.
+func LatencyMsGTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldLatencyMs, v))
+}
+
+// LatencyMsLT applies the LT predicate on the "latency_ms" field.
+func LatencyMsLT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldLatencyMs, v))
+}
+
+// LatencyMsLTE applies the LTE predicate on the "latency_ms" field.
+func LatencyMsLTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldLatencyMs, v))
+}
+
+// LatencyMsIsNil applies the IsNil predicate on the "latency_ms" field.
+func LatencyMsIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldLatencyMs))
+}
+
+// LatencyMsNotNil applies the NotNil predicate on the "latency_ms" field.
+func LatencyMsNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldLatencyMs))
+}
+
+// PingLatencyMsEQ applies the EQ predicate on the "ping_latency_ms" field.
+func PingLatencyMsEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsNEQ applies the NEQ predicate on the "ping_latency_ms" field.
+func PingLatencyMsNEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsIn applies the In predicate on the "ping_latency_ms" field.
+func PingLatencyMsIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldPingLatencyMs, vs...))
+}
+
+// PingLatencyMsNotIn applies the NotIn predicate on the "ping_latency_ms" field.
+func PingLatencyMsNotIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldPingLatencyMs, vs...))
+}
+
+// PingLatencyMsGT applies the GT predicate on the "ping_latency_ms" field.
+func PingLatencyMsGT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsGTE applies the GTE predicate on the "ping_latency_ms" field.
+func PingLatencyMsGTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsLT applies the LT predicate on the "ping_latency_ms" field.
+func PingLatencyMsLT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsLTE applies the LTE predicate on the "ping_latency_ms" field.
+func PingLatencyMsLTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsIsNil applies the IsNil predicate on the "ping_latency_ms" field.
+func PingLatencyMsIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldPingLatencyMs))
+}
+
+// PingLatencyMsNotNil applies the NotNil predicate on the "ping_latency_ms" field.
+func PingLatencyMsNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldPingLatencyMs))
+}
+
+// MessageEQ applies the EQ predicate on the "message" field.
+func MessageEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMessage, v))
+}
+
+// MessageNEQ applies the NEQ predicate on the "message" field.
+func MessageNEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldMessage, v))
+}
+
+// MessageIn applies the In predicate on the "message" field.
+func MessageIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldMessage, vs...))
+}
+
+// MessageNotIn applies the NotIn predicate on the "message" field.
+func MessageNotIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldMessage, vs...))
+}
+
+// MessageGT applies the GT predicate on the "message" field.
+func MessageGT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldMessage, v))
+}
+
+// MessageGTE applies the GTE predicate on the "message" field.
+func MessageGTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldMessage, v))
+}
+
+// MessageLT applies the LT predicate on the "message" field.
+func MessageLT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldMessage, v))
+}
+
+// MessageLTE applies the LTE predicate on the "message" field.
+func MessageLTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldMessage, v))
+}
+
+// MessageContains applies the Contains predicate on the "message" field.
+func MessageContains(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContains(FieldMessage, v))
+}
+
+// MessageHasPrefix applies the HasPrefix predicate on the "message" field.
+func MessageHasPrefix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasPrefix(FieldMessage, v))
+}
+
+// MessageHasSuffix applies the HasSuffix predicate on the "message" field.
+func MessageHasSuffix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasSuffix(FieldMessage, v))
+}
+
+// MessageIsNil applies the IsNil predicate on the "message" field.
+func MessageIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldMessage))
+}
+
+// MessageNotNil applies the NotNil predicate on the "message" field.
+func MessageNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldMessage))
+}
+
+// MessageEqualFold applies the EqualFold predicate on the "message" field.
+func MessageEqualFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEqualFold(FieldMessage, v))
+}
+
+// MessageContainsFold applies the ContainsFold predicate on the "message" field.
+func MessageContainsFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContainsFold(FieldMessage, v))
+}
+
+// CheckedAtEQ applies the EQ predicate on the "checked_at" field.
+func CheckedAtEQ(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v))
+}
+
+// CheckedAtNEQ applies the NEQ predicate on the "checked_at" field.
+func CheckedAtNEQ(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldCheckedAt, v))
+}
+
+// CheckedAtIn applies the In predicate on the "checked_at" field.
+func CheckedAtIn(vs ...time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldCheckedAt, vs...))
+}
+
+// CheckedAtNotIn applies the NotIn predicate on the "checked_at" field.
+func CheckedAtNotIn(vs ...time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldCheckedAt, vs...))
+}
+
+// CheckedAtGT applies the GT predicate on the "checked_at" field.
+func CheckedAtGT(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldCheckedAt, v))
+}
+
+// CheckedAtGTE applies the GTE predicate on the "checked_at" field.
+func CheckedAtGTE(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldCheckedAt, v))
+}
+
+// CheckedAtLT applies the LT predicate on the "checked_at" field.
+func CheckedAtLT(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldCheckedAt, v))
+}
+
+// CheckedAtLTE applies the LTE predicate on the "checked_at" field.
+func CheckedAtLTE(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldCheckedAt, v))
+}
+
+// HasMonitor applies the HasEdge predicate on the "monitor" edge.
+func HasMonitor() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates).
+func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(func(s *sql.Selector) {
+ step := newMonitorStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitorhistory_create.go b/backend/ent/channelmonitorhistory_create.go
new file mode 100644
index 00000000..71034865
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_create.go
@@ -0,0 +1,947 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+)
+
+// ChannelMonitorHistoryCreate is the builder for creating a ChannelMonitorHistory entity.
+type ChannelMonitorHistoryCreate struct {
+ config
+ mutation *ChannelMonitorHistoryMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_c *ChannelMonitorHistoryCreate) SetMonitorID(v int64) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetMonitorID(v)
+ return _c
+}
+
+// SetModel sets the "model" field.
+func (_c *ChannelMonitorHistoryCreate) SetModel(v string) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetModel(v)
+ return _c
+}
+
+// SetStatus sets the "status" field.
+func (_c *ChannelMonitorHistoryCreate) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetStatus(v)
+ return _c
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (_c *ChannelMonitorHistoryCreate) SetLatencyMs(v int) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetLatencyMs(v)
+ return _c
+}
+
+// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (_c *ChannelMonitorHistoryCreate) SetPingLatencyMs(v int) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetPingLatencyMs(v)
+ return _c
+}
+
+// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetPingLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetMessage sets the "message" field.
+func (_c *ChannelMonitorHistoryCreate) SetMessage(v string) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetMessage(v)
+ return _c
+}
+
+// SetNillableMessage sets the "message" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableMessage(v *string) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetMessage(*v)
+ }
+ return _c
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (_c *ChannelMonitorHistoryCreate) SetCheckedAt(v time.Time) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetCheckedAt(v)
+ return _c
+}
+
+// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetCheckedAt(*v)
+ }
+ return _c
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_c *ChannelMonitorHistoryCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryCreate {
+ return _c.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorHistoryMutation object of the builder.
+func (_c *ChannelMonitorHistoryCreate) Mutation() *ChannelMonitorHistoryMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorHistory in the database.
+func (_c *ChannelMonitorHistoryCreate) Save(ctx context.Context) (*ChannelMonitorHistory, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorHistoryCreate) SaveX(ctx context.Context) *ChannelMonitorHistory {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorHistoryCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorHistoryCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorHistoryCreate) defaults() {
+ if _, ok := _c.mutation.Message(); !ok {
+ v := channelmonitorhistory.DefaultMessage
+ _c.mutation.SetMessage(v)
+ }
+ if _, ok := _c.mutation.CheckedAt(); !ok {
+ v := channelmonitorhistory.DefaultCheckedAt()
+ _c.mutation.SetCheckedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorHistoryCreate) check() error {
+ if _, ok := _c.mutation.MonitorID(); !ok {
+ return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorHistory.monitor_id"`)}
+ }
+ if _, ok := _c.mutation.Model(); !ok {
+ return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorHistory.model"`)}
+ }
+ if v, ok := _c.mutation.Model(); ok {
+ if err := channelmonitorhistory.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "ChannelMonitorHistory.status"`)}
+ }
+ if v, ok := _c.mutation.Status(); ok {
+ if err := channelmonitorhistory.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.Message(); ok {
+ if err := channelmonitorhistory.MessageValidator(v); err != nil {
+ return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CheckedAt(); !ok {
+ return &ValidationError{Name: "checked_at", err: errors.New(`ent: missing required field "ChannelMonitorHistory.checked_at"`)}
+ }
+ if len(_c.mutation.MonitorIDs()) == 0 {
+ return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorHistory.monitor"`)}
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorHistoryCreate) sqlSave(ctx context.Context) (*ChannelMonitorHistory, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorHistory{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.Model(); ok {
+ _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
+ _node.Model = value
+ }
+ if value, ok := _c.mutation.Status(); ok {
+ _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value)
+ _node.Status = value
+ }
+ if value, ok := _c.mutation.LatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ _node.LatencyMs = &value
+ }
+ if value, ok := _c.mutation.PingLatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ _node.PingLatencyMs = &value
+ }
+ if value, ok := _c.mutation.Message(); ok {
+ _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value)
+ _node.Message = value
+ }
+ if value, ok := _c.mutation.CheckedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value)
+ _node.CheckedAt = value
+ }
+ if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.MonitorID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorHistory.Create().
+// SetMonitorID(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorHistoryUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreate) OnConflictColumns(columns ...string) *ChannelMonitorHistoryUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorHistoryUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorHistoryUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorHistory node.
+ ChannelMonitorHistoryUpsertOne struct {
+ create *ChannelMonitorHistoryCreate
+ }
+
+ // ChannelMonitorHistoryUpsert is the "OnConflict" setter.
+ ChannelMonitorHistoryUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorHistoryUpsert) SetMonitorID(v int64) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldMonitorID, v)
+ return u
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateMonitorID() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldMonitorID)
+ return u
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorHistoryUpsert) SetModel(v string) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldModel, v)
+ return u
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateModel() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldModel)
+ return u
+}
+
+// SetStatus sets the "status" field.
+func (u *ChannelMonitorHistoryUpsert) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldStatus, v)
+ return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateStatus() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldStatus)
+ return u
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) SetLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldLatencyMs, v)
+ return u
+}
+
+// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldLatencyMs)
+ return u
+}
+
+// AddLatencyMs adds v to the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) AddLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Add(channelmonitorhistory.FieldLatencyMs, v)
+ return u
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) ClearLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldLatencyMs)
+ return u
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldPingLatencyMs, v)
+ return u
+}
+
+// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldPingLatencyMs)
+ return u
+}
+
+// AddPingLatencyMs adds v to the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Add(channelmonitorhistory.FieldPingLatencyMs, v)
+ return u
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) ClearPingLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldPingLatencyMs)
+ return u
+}
+
+// SetMessage sets the "message" field.
+func (u *ChannelMonitorHistoryUpsert) SetMessage(v string) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldMessage, v)
+ return u
+}
+
+// UpdateMessage sets the "message" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateMessage() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldMessage)
+ return u
+}
+
+// ClearMessage clears the value of the "message" field.
+func (u *ChannelMonitorHistoryUpsert) ClearMessage() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldMessage)
+ return u
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (u *ChannelMonitorHistoryUpsert) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldCheckedAt, v)
+ return u
+}
+
+// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateCheckedAt() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldCheckedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertOne) UpdateNewValues() *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertOne) Ignore() *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorHistoryUpsertOne) DoNothing() *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorHistoryCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorHistoryUpsertOne) Update(set func(*ChannelMonitorHistoryUpsert)) *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorHistoryUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateMonitorID() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetModel(v string) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateModel() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateStatus() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetLatencyMs(v)
+ })
+}
+
+// AddLatencyMs adds v to the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) AddLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddLatencyMs(v)
+ })
+}
+
+// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateLatencyMs()
+ })
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearLatencyMs()
+ })
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetPingLatencyMs(v)
+ })
+}
+
+// AddPingLatencyMs adds v to the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddPingLatencyMs(v)
+ })
+}
+
+// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdatePingLatencyMs()
+ })
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearPingLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearPingLatencyMs()
+ })
+}
+
+// SetMessage sets the "message" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetMessage(v string) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMessage(v)
+ })
+}
+
+// UpdateMessage sets the "message" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateMessage() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMessage()
+ })
+}
+
+// ClearMessage clears the value of the "message" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearMessage() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearMessage()
+ })
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetCheckedAt(v)
+ })
+}
+
+// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateCheckedAt() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateCheckedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorHistoryUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorHistoryCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorHistoryUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorHistoryUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorHistoryUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorHistoryCreateBulk is the builder for creating many ChannelMonitorHistory entities in bulk.
+type ChannelMonitorHistoryCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorHistoryCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorHistory entities in the database.
+func (_c *ChannelMonitorHistoryCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorHistory, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorHistory, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorHistoryMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorHistoryCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorHistory {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorHistoryCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorHistoryCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorHistory.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorHistoryUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorHistoryUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorHistoryUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorHistoryUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorHistory nodes.
+type ChannelMonitorHistoryUpsertBulk struct {
+ create *ChannelMonitorHistoryCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateNewValues() *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertBulk) Ignore() *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorHistoryUpsertBulk) DoNothing() *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorHistoryCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorHistoryUpsertBulk) Update(set func(*ChannelMonitorHistoryUpsert)) *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorHistoryUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateMonitorID() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetModel(v string) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateModel() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateStatus() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetLatencyMs(v)
+ })
+}
+
+// AddLatencyMs adds v to the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) AddLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddLatencyMs(v)
+ })
+}
+
+// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateLatencyMs()
+ })
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearLatencyMs()
+ })
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetPingLatencyMs(v)
+ })
+}
+
+// AddPingLatencyMs adds v to the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddPingLatencyMs(v)
+ })
+}
+
+// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdatePingLatencyMs()
+ })
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearPingLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearPingLatencyMs()
+ })
+}
+
+// SetMessage sets the "message" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetMessage(v string) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMessage(v)
+ })
+}
+
+// UpdateMessage sets the "message" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateMessage() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMessage()
+ })
+}
+
+// ClearMessage clears the value of the "message" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearMessage() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearMessage()
+ })
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetCheckedAt(v)
+ })
+}
+
+// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateCheckedAt() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateCheckedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorHistoryUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorHistoryCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorHistoryCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorHistoryUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorhistory_delete.go b/backend/ent/channelmonitorhistory_delete.go
new file mode 100644
index 00000000..97110e69
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorHistoryDelete is the builder for deleting a ChannelMonitorHistory entity.
+type ChannelMonitorHistoryDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorHistoryMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryDelete builder.
+func (_d *ChannelMonitorHistoryDelete) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorHistoryDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorHistoryDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorHistoryDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorHistoryDeleteOne is the builder for deleting a single ChannelMonitorHistory entity.
+type ChannelMonitorHistoryDeleteOne struct {
+ _d *ChannelMonitorHistoryDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryDelete builder.
+func (_d *ChannelMonitorHistoryDeleteOne) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorHistoryDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitorhistory.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorHistoryDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorhistory_query.go b/backend/ent/channelmonitorhistory_query.go
new file mode 100644
index 00000000..1fb872ad
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorHistoryQuery is the builder for querying ChannelMonitorHistory entities.
+type ChannelMonitorHistoryQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitorhistory.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorHistory
+ withMonitor *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorHistoryQuery builder.
+func (_q *ChannelMonitorHistoryQuery) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorHistoryQuery) Limit(limit int) *ChannelMonitorHistoryQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorHistoryQuery) Offset(offset int) *ChannelMonitorHistoryQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorHistoryQuery) Unique(unique bool) *ChannelMonitorHistoryQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorHistoryQuery) Order(o ...channelmonitorhistory.OrderOption) *ChannelMonitorHistoryQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitor chains the current query on the "monitor" edge.
+func (_q *ChannelMonitorHistoryQuery) QueryMonitor() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorhistory.Table, channelmonitorhistory.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitorhistory.MonitorTable, channelmonitorhistory.MonitorColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorHistory entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorHistory was found.
+func (_q *ChannelMonitorHistoryQuery) First(ctx context.Context) (*ChannelMonitorHistory, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitorhistory.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) FirstX(ctx context.Context) *ChannelMonitorHistory {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorHistory ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorHistory ID was found.
+func (_q *ChannelMonitorHistoryQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitorhistory.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorHistory entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorHistory entity is found.
+// Returns a *NotFoundError when no ChannelMonitorHistory entities are found.
+func (_q *ChannelMonitorHistoryQuery) Only(ctx context.Context) (*ChannelMonitorHistory, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitorhistory.Label}
+ default:
+ return nil, &NotSingularError{channelmonitorhistory.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) OnlyX(ctx context.Context) *ChannelMonitorHistory {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorHistory ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorHistory ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorHistoryQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitorhistory.Label}
+ default:
+ err = &NotSingularError{channelmonitorhistory.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorHistories.
+func (_q *ChannelMonitorHistoryQuery) All(ctx context.Context) ([]*ChannelMonitorHistory, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorHistory, *ChannelMonitorHistoryQuery]()
+ return withInterceptors[[]*ChannelMonitorHistory](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) AllX(ctx context.Context) []*ChannelMonitorHistory {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorHistory IDs.
+func (_q *ChannelMonitorHistoryQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitorhistory.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorHistoryQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorHistoryQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorHistoryQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorHistoryQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorHistoryQuery) Clone() *ChannelMonitorHistoryQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorHistoryQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitorhistory.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorHistory{}, _q.predicates...),
+ withMonitor: _q.withMonitor.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitor tells the query-builder to eager-load the nodes that are connected to
+// the "monitor" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorHistoryQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorHistoryQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitor = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorHistory.Query().
+// GroupBy(channelmonitorhistory.FieldMonitorID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *ChannelMonitorHistoryGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorHistoryGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitorhistory.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// }
+//
+// client.ChannelMonitorHistory.Query().
+// Select(channelmonitorhistory.FieldMonitorID).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorHistoryQuery) Select(fields ...string) *ChannelMonitorHistorySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorHistorySelect{ChannelMonitorHistoryQuery: _q}
+ sbuild.label = channelmonitorhistory.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorHistorySelect configured with the given aggregations.
+func (_q *ChannelMonitorHistoryQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistorySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorHistoryQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitorhistory.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorHistoryQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorHistory, error) {
+ var (
+ nodes = []*ChannelMonitorHistory{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitor != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorHistory).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorHistory{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitor; query != nil {
+ if err := _q.loadMonitor(ctx, query, nodes, nil,
+ func(n *ChannelMonitorHistory, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorHistoryQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorHistory, init func(*ChannelMonitorHistory), assign func(*ChannelMonitorHistory, *ChannelMonitor)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitorHistory)
+ for i := range nodes {
+ fk := nodes[i].MonitorID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitor.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorHistoryQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorHistoryQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorhistory.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitorhistory.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withMonitor != nil {
+ _spec.Node.AddColumnOnce(channelmonitorhistory.FieldMonitorID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorHistoryQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitorhistory.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitorhistory.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorHistoryQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorHistoryQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorHistoryQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorHistoryQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorHistoryGroupBy is the group-by builder for ChannelMonitorHistory entities.
+type ChannelMonitorHistoryGroupBy struct {
+ selector
+ build *ChannelMonitorHistoryQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorHistoryGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistoryGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorHistoryGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorHistoryQuery, *ChannelMonitorHistoryGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorHistoryGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorHistoryQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorHistorySelect is the builder for selecting fields of ChannelMonitorHistory entities.
+type ChannelMonitorHistorySelect struct {
+ *ChannelMonitorHistoryQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorHistorySelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistorySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorHistorySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorHistoryQuery, *ChannelMonitorHistorySelect](ctx, _s.ChannelMonitorHistoryQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorHistorySelect) sqlScan(ctx context.Context, root *ChannelMonitorHistoryQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitorhistory_update.go b/backend/ent/channelmonitorhistory_update.go
new file mode 100644
index 00000000..a85a8072
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_update.go
@@ -0,0 +1,635 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorHistoryUpdate is the builder for updating ChannelMonitorHistory entities.
+type ChannelMonitorHistoryUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorHistoryMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryUpdate builder.
+func (_u *ChannelMonitorHistoryUpdate) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorHistoryUpdate) SetMonitorID(v int64) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorHistoryUpdate) SetModel(v string) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableModel(v *string) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *ChannelMonitorHistoryUpdate) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableStatus(v *channelmonitorhistory.Status) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) SetLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.ResetLatencyMs()
+ _u.mutation.SetLatencyMs(v)
+ return _u
+}
+
+// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddLatencyMs adds value to the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) AddLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.AddLatencyMs(v)
+ return _u
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearLatencyMs() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearLatencyMs()
+ return _u
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.ResetPingLatencyMs()
+ _u.mutation.SetPingLatencyMs(v)
+ return _u
+}
+
+// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddPingLatencyMs adds value to the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.AddPingLatencyMs(v)
+ return _u
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearPingLatencyMs() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearPingLatencyMs()
+ return _u
+}
+
+// SetMessage sets the "message" field.
+func (_u *ChannelMonitorHistoryUpdate) SetMessage(v string) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetMessage(v)
+ return _u
+}
+
+// SetNillableMessage sets the "message" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableMessage(v *string) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetMessage(*v)
+ }
+ return _u
+}
+
+// ClearMessage clears the value of the "message" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearMessage() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearMessage()
+ return _u
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (_u *ChannelMonitorHistoryUpdate) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetCheckedAt(v)
+ return _u
+}
+
+// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetCheckedAt(*v)
+ }
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryUpdate {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorHistoryMutation object of the builder.
+func (_u *ChannelMonitorHistoryUpdate) Mutation() *ChannelMonitorHistoryMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdate) ClearMonitor() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorHistoryUpdate) Save(ctx context.Context) (int, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorHistoryUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorHistoryUpdate) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitorhistory.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := channelmonitorhistory.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Message(); ok {
+ if err := channelmonitorhistory.MessageValidator(v); err != nil {
+ return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorHistory.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorHistoryUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.LatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.LatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.PingLatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedPingLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.PingLatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.Message(); ok {
+ _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value)
+ }
+ if _u.mutation.MessageCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldMessage, field.TypeString)
+ }
+ if value, ok := _u.mutation.CheckedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorhistory.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorHistoryUpdateOne is the builder for updating a single ChannelMonitorHistory entity.
+type ChannelMonitorHistoryUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorHistoryMutation
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetModel(v string) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableModel(v *string) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableStatus(v *channelmonitorhistory.Status) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ResetLatencyMs()
+ _u.mutation.SetLatencyMs(v)
+ return _u
+}
+
+// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddLatencyMs adds value to the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) AddLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.AddLatencyMs(v)
+ return _u
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearLatencyMs() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearLatencyMs()
+ return _u
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ResetPingLatencyMs()
+ _u.mutation.SetPingLatencyMs(v)
+ return _u
+}
+
+// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddPingLatencyMs adds value to the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.AddPingLatencyMs(v)
+ return _u
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearPingLatencyMs() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearPingLatencyMs()
+ return _u
+}
+
+// SetMessage sets the "message" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetMessage(v string) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetMessage(v)
+ return _u
+}
+
+// SetNillableMessage sets the "message" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableMessage(v *string) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetMessage(*v)
+ }
+ return _u
+}
+
+// ClearMessage clears the value of the "message" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearMessage() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearMessage()
+ return _u
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetCheckedAt(v)
+ return _u
+}
+
+// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetCheckedAt(*v)
+ }
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryUpdateOne {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorHistoryMutation object of the builder.
+func (_u *ChannelMonitorHistoryUpdateOne) Mutation() *ChannelMonitorHistoryMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearMonitor() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryUpdate builder.
+func (_u *ChannelMonitorHistoryUpdateOne) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorHistoryUpdateOne) Select(field string, fields ...string) *ChannelMonitorHistoryUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorHistory entity.
+func (_u *ChannelMonitorHistoryUpdateOne) Save(ctx context.Context) (*ChannelMonitorHistory, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdateOne) SaveX(ctx context.Context) *ChannelMonitorHistory {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorHistoryUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorHistoryUpdateOne) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitorhistory.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := channelmonitorhistory.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Message(); ok {
+ if err := channelmonitorhistory.MessageValidator(v); err != nil {
+ return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorHistory.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorHistoryUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorHistory, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorHistory.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorhistory.FieldID)
+ for _, f := range fields {
+ if !channelmonitorhistory.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitorhistory.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.LatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.LatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.PingLatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedPingLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.PingLatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.Message(); ok {
+ _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value)
+ }
+ if _u.mutation.MessageCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldMessage, field.TypeString)
+ }
+ if value, ok := _u.mutation.CheckedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorHistory{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorhistory.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate.go
new file mode 100644
index 00000000..b8429a4d
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate.go
@@ -0,0 +1,216 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorRequestTemplate is the model entity for the ChannelMonitorRequestTemplate schema.
+type ChannelMonitorRequestTemplate struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Provider holds the value of the "provider" field.
+ Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"`
+ // Description holds the value of the "description" field.
+ Description string `json:"description,omitempty"`
+ // ExtraHeaders holds the value of the "extra_headers" field.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // BodyOverrideMode holds the value of the "body_override_mode" field.
+ BodyOverrideMode string `json:"body_override_mode,omitempty"`
+ // BodyOverride holds the value of the "body_override" field.
+ BodyOverride map[string]interface{} `json:"body_override,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorRequestTemplateQuery when eager-loading is set.
+ Edges ChannelMonitorRequestTemplateEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorRequestTemplateEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorRequestTemplateEdges struct {
+ // Monitors holds the value of the monitors edge.
+ Monitors []*ChannelMonitor `json:"monitors,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorsOrErr returns the Monitors value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorRequestTemplateEdges) MonitorsOrErr() ([]*ChannelMonitor, error) {
+ if e.loadedTypes[0] {
+ return e.Monitors, nil
+ }
+ return nil, &NotLoadedError{edge: "monitors"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorrequesttemplate.FieldExtraHeaders, channelmonitorrequesttemplate.FieldBodyOverride:
+ values[i] = new([]byte)
+ case channelmonitorrequesttemplate.FieldID:
+ values[i] = new(sql.NullInt64)
+ case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ values[i] = new(sql.NullString)
+ case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorRequestTemplate fields.
+func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorrequesttemplate.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case channelmonitorrequesttemplate.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case channelmonitorrequesttemplate.FieldProvider:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider", values[i])
+ } else if value.Valid {
+ _m.Provider = channelmonitorrequesttemplate.Provider(value.String)
+ }
+ case channelmonitorrequesttemplate.FieldDescription:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field description", values[i])
+ } else if value.Valid {
+ _m.Description = value.String
+ }
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_headers", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil {
+ return fmt.Errorf("unmarshal field extra_headers: %w", err)
+ }
+ }
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override_mode", values[i])
+ } else if value.Valid {
+ _m.BodyOverrideMode = value.String
+ }
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil {
+ return fmt.Errorf("unmarshal field body_override: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorRequestTemplate.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorRequestTemplate) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitors queries the "monitors" edge of the ChannelMonitorRequestTemplate entity.
+func (_m *ChannelMonitorRequestTemplate) QueryMonitors() *ChannelMonitorQuery {
+ return NewChannelMonitorRequestTemplateClient(_m.config).QueryMonitors(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorRequestTemplate.
+// Note that you need to call ChannelMonitorRequestTemplate.Unwrap() before calling this method if this ChannelMonitorRequestTemplate
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorRequestTemplate) Update() *ChannelMonitorRequestTemplateUpdateOne {
+ return NewChannelMonitorRequestTemplateClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorRequestTemplate entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorRequestTemplate) Unwrap() *ChannelMonitorRequestTemplate {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorRequestTemplate is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorRequestTemplate) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorRequestTemplate(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("provider=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Provider))
+ builder.WriteString(", ")
+ builder.WriteString("description=")
+ builder.WriteString(_m.Description)
+ builder.WriteString(", ")
+ builder.WriteString("extra_headers=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders))
+ builder.WriteString(", ")
+ builder.WriteString("body_override_mode=")
+ builder.WriteString(_m.BodyOverrideMode)
+ builder.WriteString(", ")
+ builder.WriteString("body_override=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorRequestTemplates is a parsable slice of ChannelMonitorRequestTemplate.
+type ChannelMonitorRequestTemplates []*ChannelMonitorRequestTemplate
diff --git a/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
new file mode 100644
index 00000000..65b8d641
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
@@ -0,0 +1,172 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorrequesttemplate
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitorrequesttemplate type in the database.
+ Label = "channel_monitor_request_template"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldProvider holds the string denoting the provider field in the database.
+ FieldProvider = "provider"
+ // FieldDescription holds the string denoting the description field in the database.
+ FieldDescription = "description"
+ // FieldExtraHeaders holds the string denoting the extra_headers field in the database.
+ FieldExtraHeaders = "extra_headers"
+ // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database.
+ FieldBodyOverrideMode = "body_override_mode"
+ // FieldBodyOverride holds the string denoting the body_override field in the database.
+ FieldBodyOverride = "body_override"
+ // EdgeMonitors holds the string denoting the monitors edge name in mutations.
+ EdgeMonitors = "monitors"
+ // Table holds the table name of the channelmonitorrequesttemplate in the database.
+ Table = "channel_monitor_request_templates"
+ // MonitorsTable is the table that holds the monitors relation/edge.
+ MonitorsTable = "channel_monitors"
+ // MonitorsInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorsInverseTable = "channel_monitors"
+ // MonitorsColumn is the table column denoting the monitors relation/edge.
+ MonitorsColumn = "template_id"
+)
+
+// Columns holds all SQL columns for channelmonitorrequesttemplate fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldName,
+ FieldProvider,
+ FieldDescription,
+ FieldExtraHeaders,
+ FieldBodyOverrideMode,
+ FieldBodyOverride,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // DefaultDescription holds the default value on creation for the "description" field.
+ DefaultDescription string
+ // DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
+ DescriptionValidator func(string) error
+ // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field.
+ DefaultExtraHeaders map[string]string
+ // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field.
+ DefaultBodyOverrideMode string
+ // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ BodyOverrideModeValidator func(string) error
+)
+
+// Provider defines the type for the "provider" enum field.
+type Provider string
+
+// Provider values.
+const (
+ ProviderOpenai Provider = "openai"
+ ProviderAnthropic Provider = "anthropic"
+ ProviderGemini Provider = "gemini"
+)
+
+func (pr Provider) String() string {
+ return string(pr)
+}
+
+// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save.
+func ProviderValidator(pr Provider) error {
+ switch pr {
+ case ProviderOpenai, ProviderAnthropic, ProviderGemini:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitorrequesttemplate: invalid enum value for provider field: %q", pr)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitorRequestTemplate queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByProvider orders the results by the provider field.
+func ByProvider(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProvider, opts...).ToFunc()
+}
+
+// ByDescription orders the results by the description field.
+func ByDescription(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDescription, opts...).ToFunc()
+}
+
+// ByBodyOverrideMode orders the results by the body_override_mode field.
+func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc()
+}
+
+// ByMonitorsCount orders the results by monitors count.
+func ByMonitorsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newMonitorsStep(), opts...)
+ }
+}
+
+// ByMonitors orders the results by monitors terms.
+func ByMonitors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newMonitorsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn),
+ )
+}
diff --git a/backend/ent/channelmonitorrequesttemplate/where.go b/backend/ent/channelmonitorrequesttemplate/where.go
new file mode 100644
index 00000000..b95e5df0
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate/where.go
@@ -0,0 +1,434 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorrequesttemplate
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
+}
+
+// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
+func Description(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
+}
+
+// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ.
+func BodyOverrideMode(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldName, v))
+}
+
+// ProviderEQ applies the EQ predicate on the "provider" field.
+func ProviderEQ(v Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldProvider, v))
+}
+
+// ProviderNEQ applies the NEQ predicate on the "provider" field.
+func ProviderNEQ(v Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldProvider, v))
+}
+
+// ProviderIn applies the In predicate on the "provider" field.
+func ProviderIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldProvider, vs...))
+}
+
+// ProviderNotIn applies the NotIn predicate on the "provider" field.
+func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...))
+}
+
+// DescriptionEQ applies the EQ predicate on the "description" field.
+func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
+}
+
+// DescriptionNEQ applies the NEQ predicate on the "description" field.
+func DescriptionNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldDescription, v))
+}
+
+// DescriptionIn applies the In predicate on the "description" field.
+func DescriptionIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldDescription, vs...))
+}
+
+// DescriptionNotIn applies the NotIn predicate on the "description" field.
+func DescriptionNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldDescription, vs...))
+}
+
+// DescriptionGT applies the GT predicate on the "description" field.
+func DescriptionGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldDescription, v))
+}
+
+// DescriptionGTE applies the GTE predicate on the "description" field.
+func DescriptionGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldDescription, v))
+}
+
+// DescriptionLT applies the LT predicate on the "description" field.
+func DescriptionLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldDescription, v))
+}
+
+// DescriptionLTE applies the LTE predicate on the "description" field.
+func DescriptionLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldDescription, v))
+}
+
+// DescriptionContains applies the Contains predicate on the "description" field.
+func DescriptionContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldDescription, v))
+}
+
+// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
+func DescriptionHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldDescription, v))
+}
+
+// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
+func DescriptionHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldDescription, v))
+}
+
+// DescriptionIsNil applies the IsNil predicate on the "description" field.
+func DescriptionIsNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldDescription))
+}
+
+// DescriptionNotNil applies the NotNil predicate on the "description" field.
+func DescriptionNotNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldDescription))
+}
+
+// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
+func DescriptionEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldDescription, v))
+}
+
+// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
+func DescriptionContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldDescription, v))
+}
+
+// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field.
+func BodyOverrideModeEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field.
+func BodyOverrideModeNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field.
+func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field.
+func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field.
+func BodyOverrideModeGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field.
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field.
+func BodyOverrideModeLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field.
+func BodyOverrideModeLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field.
+func BodyOverrideModeContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field.
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field.
+func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field.
+func BodyOverrideIsNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldBodyOverride))
+}
+
+// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field.
+func BodyOverrideNotNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldBodyOverride))
+}
+
+// HasMonitors applies the HasEdge predicate on the "monitors" edge.
+func HasMonitors() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorsWith applies the HasEdge predicate on the "monitors" edge with a given conditions (other predicates).
+func HasMonitorsWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) {
+ step := newMonitorsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_create.go b/backend/ent/channelmonitorrequesttemplate_create.go
new file mode 100644
index 00000000..1ba842cd
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_create.go
@@ -0,0 +1,942 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorRequestTemplateCreate is the builder for creating a ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateCreate struct {
+ config
+ mutation *ChannelMonitorRequestTemplateMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetCreatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetName(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetProvider sets the "provider" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetProvider(v)
+ return _c
+}
+
+// SetDescription sets the "description" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetDescription(v)
+ return _c
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetDescription(*v)
+ }
+ return _c
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetExtraHeaders(v)
+ return _c
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetBodyOverrideMode(v)
+ return _c
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetBodyOverrideMode(*v)
+ }
+ return _c
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetBodyOverride(v)
+ return _c
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_c *ChannelMonitorRequestTemplateCreate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.AddMonitorIDs(ids...)
+ return _c
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_c *ChannelMonitorRequestTemplateCreate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_c *ChannelMonitorRequestTemplateCreate) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorRequestTemplate in the database.
+func (_c *ChannelMonitorRequestTemplateCreate) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorRequestTemplateCreate) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorRequestTemplateCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorRequestTemplateCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := channelmonitorrequesttemplate.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Description(); !ok {
+ v := channelmonitorrequesttemplate.DefaultDescription
+ _c.mutation.SetDescription(v)
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ v := channelmonitorrequesttemplate.DefaultExtraHeaders
+ _c.mutation.SetExtraHeaders(v)
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ v := channelmonitorrequesttemplate.DefaultBodyOverrideMode
+ _c.mutation.SetBodyOverrideMode(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorRequestTemplateCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Provider(); !ok {
+ return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.provider"`)}
+ }
+ if v, ok := _c.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.extra_headers"`)}
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.body_override_mode"`)}
+ }
+ if v, ok := _c.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorRequestTemplateCreate) sqlSave(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequestTemplate, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorRequestTemplate{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ _node.Provider = value
+ }
+ if value, ok := _c.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ _node.Description = value
+ }
+ if value, ok := _c.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ _node.ExtraHeaders = value
+ }
+ if value, ok := _c.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ _node.BodyOverrideMode = value
+ }
+ if value, ok := _c.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ _node.BodyOverride = value
+ }
+ if nodes := _c.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorRequestTemplateUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreate) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorRequestTemplateUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorRequestTemplateUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorRequestTemplate node.
+ ChannelMonitorRequestTemplateUpsertOne struct {
+ create *ChannelMonitorRequestTemplateCreate
+ }
+
+ // ChannelMonitorRequestTemplateUpsert is the "OnConflict" setter.
+ ChannelMonitorRequestTemplateUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldUpdatedAt)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetName(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateName() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldName)
+ return u
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldProvider, v)
+ return u
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldProvider)
+ return u
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldDescription, v)
+ return u
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateDescription() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldDescription)
+ return u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsert) ClearDescription() *ChannelMonitorRequestTemplateUpsert {
+ u.SetNull(channelmonitorrequesttemplate.FieldDescription)
+ return u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldExtraHeaders, v)
+ return u
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldExtraHeaders)
+ return u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldBodyOverrideMode, v)
+ return u
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverrideMode)
+ return u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldBodyOverride, v)
+ return u
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverride)
+ return u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsert) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsert {
+ u.SetNull(channelmonitorrequesttemplate.FieldBodyOverride)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertOne) Ignore() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorRequestTemplateUpsertOne) DoNothing() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorRequestTemplateUpsertOne) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetName(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateName() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateDescription() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearDescription() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearDescription()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorRequestTemplateUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorRequestTemplateCreateBulk is the builder for creating many ChannelMonitorRequestTemplate entities in bulk.
+type ChannelMonitorRequestTemplateCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorRequestTemplateCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorRequestTemplate entities in the database.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorRequestTemplate, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorRequestTemplateMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorRequestTemplate {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorRequestTemplate.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorRequestTemplateUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorRequestTemplateUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorRequestTemplateUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorRequestTemplate nodes.
+type ChannelMonitorRequestTemplateUpsertBulk struct {
+ create *ChannelMonitorRequestTemplateCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Ignore() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) DoNothing() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetName(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateName() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateDescription() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearDescription() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearDescription()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorRequestTemplateCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_delete.go b/backend/ent/channelmonitorrequesttemplate_delete.go
new file mode 100644
index 00000000..98d365c8
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateDelete is the builder for deleting a ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateDelete builder.
+func (_d *ChannelMonitorRequestTemplateDelete) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorRequestTemplateDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorRequestTemplateDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorRequestTemplateDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorRequestTemplateDeleteOne is the builder for deleting a single ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateDeleteOne struct {
+ _d *ChannelMonitorRequestTemplateDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateDelete builder.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_query.go b/backend/ent/channelmonitorrequesttemplate_query.go
new file mode 100644
index 00000000..6491ea60
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_query.go
@@ -0,0 +1,648 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateQuery is the builder for querying ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitorrequesttemplate.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorRequestTemplate
+ withMonitors *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorRequestTemplateQuery builder.
+func (_q *ChannelMonitorRequestTemplateQuery) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorRequestTemplateQuery) Limit(limit int) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorRequestTemplateQuery) Offset(offset int) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorRequestTemplateQuery) Unique(unique bool) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorRequestTemplateQuery) Order(o ...channelmonitorrequesttemplate.OrderOption) *ChannelMonitorRequestTemplateQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitors chains the current query on the "monitors" edge.
+func (_q *ChannelMonitorRequestTemplateQuery) QueryMonitors() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorRequestTemplate entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate was found.
+func (_q *ChannelMonitorRequestTemplateQuery) First(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitorrequesttemplate.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorRequestTemplate ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate ID was found.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorRequestTemplate entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate entity is found.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate entities are found.
+func (_q *ChannelMonitorRequestTemplateQuery) Only(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ return nil, &NotSingularError{channelmonitorrequesttemplate.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorRequestTemplate ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ err = &NotSingularError{channelmonitorrequesttemplate.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorRequestTemplates.
+func (_q *ChannelMonitorRequestTemplateQuery) All(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorRequestTemplate, *ChannelMonitorRequestTemplateQuery]()
+ return withInterceptors[[]*ChannelMonitorRequestTemplate](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) AllX(ctx context.Context) []*ChannelMonitorRequestTemplate {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorRequestTemplate IDs.
+func (_q *ChannelMonitorRequestTemplateQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitorrequesttemplate.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorRequestTemplateQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorRequestTemplateQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorRequestTemplateQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorRequestTemplateQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorRequestTemplateQuery) Clone() *ChannelMonitorRequestTemplateQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorRequestTemplateQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitorrequesttemplate.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorRequestTemplate{}, _q.predicates...),
+ withMonitors: _q.withMonitors.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitors tells the query-builder to eager-load the nodes that are connected to
+// the "monitors" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorRequestTemplateQuery) WithMonitors(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitors = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorRequestTemplate.Query().
+// GroupBy(channelmonitorrequesttemplate.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorRequestTemplateQuery) GroupBy(field string, fields ...string) *ChannelMonitorRequestTemplateGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorRequestTemplateGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitorrequesttemplate.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.ChannelMonitorRequestTemplate.Query().
+// Select(channelmonitorrequesttemplate.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorRequestTemplateQuery) Select(fields ...string) *ChannelMonitorRequestTemplateSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorRequestTemplateSelect{ChannelMonitorRequestTemplateQuery: _q}
+ sbuild.label = channelmonitorrequesttemplate.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorRequestTemplateSelect configured with the given aggregations.
+func (_q *ChannelMonitorRequestTemplateQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitorrequesttemplate.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorRequestTemplate, error) {
+ var (
+ nodes = []*ChannelMonitorRequestTemplate{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitors != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorRequestTemplate).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorRequestTemplate{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitors; query != nil {
+ if err := _q.loadMonitors(ctx, query, nodes,
+ func(n *ChannelMonitorRequestTemplate) { n.Edges.Monitors = []*ChannelMonitor{} },
+ func(n *ChannelMonitorRequestTemplate, e *ChannelMonitor) {
+ n.Edges.Monitors = append(n.Edges.Monitors, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) loadMonitors(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorRequestTemplate, init func(*ChannelMonitorRequestTemplate), assign func(*ChannelMonitorRequestTemplate, *ChannelMonitor)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitorRequestTemplate)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitor.FieldTemplateID)
+ }
+ query.Where(predicate.ChannelMonitor(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitorrequesttemplate.MonitorsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TemplateID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "template_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "template_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitorrequesttemplate.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitorrequesttemplate.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitorrequesttemplate.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorRequestTemplateQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorRequestTemplateQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorRequestTemplateGroupBy is the group-by builder for ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateGroupBy struct {
+ selector
+ build *ChannelMonitorRequestTemplateQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorRequestTemplateGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorRequestTemplateGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorRequestTemplateGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorRequestTemplateSelect is the builder for selecting fields of ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateSelect struct {
+ *ChannelMonitorRequestTemplateQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorRequestTemplateSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorRequestTemplateSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateSelect](ctx, _s.ChannelMonitorRequestTemplateQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorRequestTemplateSelect) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_update.go b/backend/ent/channelmonitorrequesttemplate_update.go
new file mode 100644
index 00000000..8f55ba04
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_update.go
@@ -0,0 +1,639 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateUpdate is the builder for updating ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetName(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearDescription() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearDescription()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.AddMonitorIDs(ids...)
+ return _u
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _u.mutation
+}
+
+// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearMonitors() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearMonitors()
+ return _u
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.RemoveMonitorIDs(ids...)
+ return _u
+}
+
+// RemoveMonitors removes "monitors" edges to ChannelMonitor entities.
+func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveMonitorIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorRequestTemplateUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorRequestTemplateUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorRequestTemplateUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ }
+ if _u.mutation.DescriptionCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorRequestTemplateUpdateOne is the builder for updating a single ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetName(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearDescription() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearDescription()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.AddMonitorIDs(ids...)
+ return _u
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _u.mutation
+}
+
+// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearMonitors() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearMonitors()
+ return _u
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.RemoveMonitorIDs(ids...)
+ return _u
+}
+
+// RemoveMonitors removes "monitors" edges to ChannelMonitor entities.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveMonitorIDs(ids...)
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Select(field string, fields ...string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorRequestTemplate, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorRequestTemplate.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID)
+ for _, f := range fields {
+ if !channelmonitorrequesttemplate.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitorrequesttemplate.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ }
+ if _u.mutation.DescriptionCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorRequestTemplate{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index e52e015a..df20ddfa 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -20,12 +20,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -60,18 +68,34 @@ type Client struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
+ // ChannelMonitor is the client for interacting with the ChannelMonitor builders.
+ ChannelMonitor *ChannelMonitorClient
+ // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
+ ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
+ // ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
+ ChannelMonitorHistory *ChannelMonitorHistoryClient
+ // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
+ ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -118,12 +142,20 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
+ c.AuthIdentity = NewAuthIdentityClient(c.config)
+ c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
+ c.ChannelMonitor = NewChannelMonitorClient(c.config)
+ c.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(c.config)
+ c.ChannelMonitorHistory = NewChannelMonitorHistoryClient(c.config)
+ c.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
+ c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config)
c.PaymentAuditLog = NewPaymentAuditLogClient(c.config)
c.PaymentOrder = NewPaymentOrderClient(c.config)
c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config)
+ c.PendingAuthSession = NewPendingAuthSessionClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
@@ -229,34 +261,42 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -274,34 +314,42 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PaymentAuditLog: NewPaymentAuditLogClient(cfg),
- PaymentOrder: NewPaymentOrderClient(cfg),
- PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- SubscriptionPlan: NewSubscriptionPlanClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -332,11 +380,14 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory,
+ c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -348,11 +399,14 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog,
- c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory,
+ c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -372,18 +426,34 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
+ case *AuthIdentityMutation:
+ return c.AuthIdentity.mutate(ctx, m)
+ case *AuthIdentityChannelMutation:
+ return c.AuthIdentityChannel.mutate(ctx, m)
+ case *ChannelMonitorMutation:
+ return c.ChannelMonitor.mutate(ctx, m)
+ case *ChannelMonitorDailyRollupMutation:
+ return c.ChannelMonitorDailyRollup.mutate(ctx, m)
+ case *ChannelMonitorHistoryMutation:
+ return c.ChannelMonitorHistory.mutate(ctx, m)
+ case *ChannelMonitorRequestTemplateMutation:
+ return c.ChannelMonitorRequestTemplate.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *IdempotencyRecordMutation:
return c.IdempotencyRecord.mutate(ctx, m)
+ case *IdentityAdoptionDecisionMutation:
+ return c.IdentityAdoptionDecision.mutate(ctx, m)
case *PaymentAuditLogMutation:
return c.PaymentAuditLog.mutate(ctx, m)
case *PaymentOrderMutation:
return c.PaymentOrder.mutate(ctx, m)
case *PaymentProviderInstanceMutation:
return c.PaymentProviderInstance.mutate(ctx, m)
+ case *PendingAuthSessionMutation:
+ return c.PendingAuthSession.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@@ -1231,6 +1301,964 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
}
}
+// AuthIdentityClient is a client for the AuthIdentity schema.
+type AuthIdentityClient struct {
+ config
+}
+
+// NewAuthIdentityClient returns a client for the AuthIdentity from the given config.
+func NewAuthIdentityClient(c config) *AuthIdentityClient {
+ return &AuthIdentityClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`.
+func (c *AuthIdentityClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`.
+func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentity entity.
+func (c *AuthIdentityClient) Create() *AuthIdentityCreate {
+ mutation := newAuthIdentityMutation(c.config, OpCreate)
+ return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentity entities.
+func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk {
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentity.
+func (c *AuthIdentityClient) Update() *AuthIdentityUpdate {
+ mutation := newAuthIdentityMutation(c.config, OpUpdate)
+ return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentity.
+func (c *AuthIdentityClient) Delete() *AuthIdentityDelete {
+ mutation := newAuthIdentityMutation(c.config, OpDelete)
+ return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne {
+ builder := c.Delete().Where(authidentity.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentity.
+func (c *AuthIdentityClient) Query() *AuthIdentityQuery {
+ return &AuthIdentityQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentity},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentity entity by its id.
+func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) {
+ return c.Query().Where(authidentity.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryUser queries the user edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryChannels queries the channels edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityClient) Hooks() []Hook {
+ return c.hooks.AuthIdentity
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentity
+}
+
+func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op())
+ }
+}
+
+// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema.
+type AuthIdentityChannelClient struct {
+ config
+}
+
+// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config.
+func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient {
+ return &AuthIdentityChannelClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentityChannel entity.
+func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpCreate)
+ return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities.
+func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk {
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityChannelCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdate)
+ return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete {
+ mutation := newAuthIdentityChannelMutation(c.config, OpDelete)
+ return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne {
+ builder := c.Delete().Where(authidentitychannel.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityChannelDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery {
+ return &AuthIdentityChannelQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentityChannel},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentityChannel entity by its id.
+func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) {
+ return c.Query().Where(authidentitychannel.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryIdentity queries the identity edge of a AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityChannelClient) Hooks() []Hook {
+ return c.hooks.AuthIdentityChannel
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityChannelClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentityChannel
+}
+
+func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorClient is a client for the ChannelMonitor schema.
+type ChannelMonitorClient struct {
+ config
+}
+
+// NewChannelMonitorClient returns a client for the ChannelMonitor from the given config.
+func NewChannelMonitorClient(c config) *ChannelMonitorClient {
+ return &ChannelMonitorClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitor.Hooks(f(g(h())))`.
+func (c *ChannelMonitorClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitor = append(c.hooks.ChannelMonitor, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitor.Intercept(f(g(h())))`.
+func (c *ChannelMonitorClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitor = append(c.inters.ChannelMonitor, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitor entity.
+func (c *ChannelMonitorClient) Create() *ChannelMonitorCreate {
+ mutation := newChannelMonitorMutation(c.config, OpCreate)
+ return &ChannelMonitorCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitor entities.
+func (c *ChannelMonitorClient) CreateBulk(builders ...*ChannelMonitorCreate) *ChannelMonitorCreateBulk {
+ return &ChannelMonitorCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorCreate, int)) *ChannelMonitorCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorCreateBulk{err: fmt.Errorf("calling to ChannelMonitorClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitor.
+func (c *ChannelMonitorClient) Update() *ChannelMonitorUpdate {
+ mutation := newChannelMonitorMutation(c.config, OpUpdate)
+ return &ChannelMonitorUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorClient) UpdateOne(_m *ChannelMonitor) *ChannelMonitorUpdateOne {
+ mutation := newChannelMonitorMutation(c.config, OpUpdateOne, withChannelMonitor(_m))
+ return &ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorClient) UpdateOneID(id int64) *ChannelMonitorUpdateOne {
+ mutation := newChannelMonitorMutation(c.config, OpUpdateOne, withChannelMonitorID(id))
+ return &ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitor.
+func (c *ChannelMonitorClient) Delete() *ChannelMonitorDelete {
+ mutation := newChannelMonitorMutation(c.config, OpDelete)
+ return &ChannelMonitorDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorClient) DeleteOne(_m *ChannelMonitor) *ChannelMonitorDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorClient) DeleteOneID(id int64) *ChannelMonitorDeleteOne {
+ builder := c.Delete().Where(channelmonitor.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitor.
+func (c *ChannelMonitorClient) Query() *ChannelMonitorQuery {
+ return &ChannelMonitorQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitor},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitor entity by its id.
+func (c *ChannelMonitorClient) Get(ctx context.Context, id int64) (*ChannelMonitor, error) {
+ return c.Query().Where(channelmonitor.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorClient) GetX(ctx context.Context, id int64) *ChannelMonitor {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryHistory queries the history edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryHistory(_m *ChannelMonitor) *ChannelMonitorHistoryQuery {
+ query := (&ChannelMonitorHistoryClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitorhistory.Table, channelmonitorhistory.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.HistoryTable, channelmonitor.HistoryColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryDailyRollups queries the daily_rollups edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryDailyRollups(_m *ChannelMonitor) *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryRequestTemplate queries the request_template edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryRequestTemplate(_m *ChannelMonitor) *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitor
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitor
+}
+
+func (c *ChannelMonitorClient) mutate(ctx context.Context, m *ChannelMonitorMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitor mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorDailyRollupClient is a client for the ChannelMonitorDailyRollup schema.
+type ChannelMonitorDailyRollupClient struct {
+ config
+}
+
+// NewChannelMonitorDailyRollupClient returns a client for the ChannelMonitorDailyRollup from the given config.
+func NewChannelMonitorDailyRollupClient(c config) *ChannelMonitorDailyRollupClient {
+ return &ChannelMonitorDailyRollupClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitordailyrollup.Hooks(f(g(h())))`.
+func (c *ChannelMonitorDailyRollupClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorDailyRollup = append(c.hooks.ChannelMonitorDailyRollup, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitordailyrollup.Intercept(f(g(h())))`.
+func (c *ChannelMonitorDailyRollupClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorDailyRollup = append(c.inters.ChannelMonitorDailyRollup, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorDailyRollup entity.
+func (c *ChannelMonitorDailyRollupClient) Create() *ChannelMonitorDailyRollupCreate {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpCreate)
+ return &ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorDailyRollup entities.
+func (c *ChannelMonitorDailyRollupClient) CreateBulk(builders ...*ChannelMonitorDailyRollupCreate) *ChannelMonitorDailyRollupCreateBulk {
+ return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorDailyRollupClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorDailyRollupCreate, int)) *ChannelMonitorDailyRollupCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorDailyRollupCreateBulk{err: fmt.Errorf("calling to ChannelMonitorDailyRollupClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorDailyRollupCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Update() *ChannelMonitorDailyRollupUpdate {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdate)
+ return &ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorDailyRollupClient) UpdateOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollup(_m))
+ return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorDailyRollupClient) UpdateOneID(id int64) *ChannelMonitorDailyRollupUpdateOne {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollupID(id))
+ return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Delete() *ChannelMonitorDailyRollupDelete {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpDelete)
+ return &ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorDailyRollupClient) DeleteOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorDailyRollupClient) DeleteOneID(id int64) *ChannelMonitorDailyRollupDeleteOne {
+ builder := c.Delete().Where(channelmonitordailyrollup.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorDailyRollupDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Query() *ChannelMonitorDailyRollupQuery {
+ return &ChannelMonitorDailyRollupQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorDailyRollup},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorDailyRollup entity by its id.
+func (c *ChannelMonitorDailyRollupClient) Get(ctx context.Context, id int64) (*ChannelMonitorDailyRollup, error) {
+ return c.Query().Where(channelmonitordailyrollup.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorDailyRollupClient) GetX(ctx context.Context, id int64) *ChannelMonitorDailyRollup {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitor queries the monitor edge of a ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) QueryMonitor(_m *ChannelMonitorDailyRollup) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorDailyRollupClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorDailyRollup
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorDailyRollupClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorDailyRollup
+}
+
+func (c *ChannelMonitorDailyRollupClient) mutate(ctx context.Context, m *ChannelMonitorDailyRollupMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorDailyRollup mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorHistoryClient is a client for the ChannelMonitorHistory schema.
+type ChannelMonitorHistoryClient struct {
+ config
+}
+
+// NewChannelMonitorHistoryClient returns a client for the ChannelMonitorHistory from the given config.
+func NewChannelMonitorHistoryClient(c config) *ChannelMonitorHistoryClient {
+ return &ChannelMonitorHistoryClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitorhistory.Hooks(f(g(h())))`.
+func (c *ChannelMonitorHistoryClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorHistory = append(c.hooks.ChannelMonitorHistory, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitorhistory.Intercept(f(g(h())))`.
+func (c *ChannelMonitorHistoryClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorHistory = append(c.inters.ChannelMonitorHistory, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorHistory entity.
+func (c *ChannelMonitorHistoryClient) Create() *ChannelMonitorHistoryCreate {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpCreate)
+ return &ChannelMonitorHistoryCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorHistory entities.
+func (c *ChannelMonitorHistoryClient) CreateBulk(builders ...*ChannelMonitorHistoryCreate) *ChannelMonitorHistoryCreateBulk {
+ return &ChannelMonitorHistoryCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorHistoryClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorHistoryCreate, int)) *ChannelMonitorHistoryCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorHistoryCreateBulk{err: fmt.Errorf("calling to ChannelMonitorHistoryClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorHistoryCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorHistoryCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) Update() *ChannelMonitorHistoryUpdate {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpUpdate)
+ return &ChannelMonitorHistoryUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorHistoryClient) UpdateOne(_m *ChannelMonitorHistory) *ChannelMonitorHistoryUpdateOne {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpUpdateOne, withChannelMonitorHistory(_m))
+ return &ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorHistoryClient) UpdateOneID(id int64) *ChannelMonitorHistoryUpdateOne {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpUpdateOne, withChannelMonitorHistoryID(id))
+ return &ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) Delete() *ChannelMonitorHistoryDelete {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpDelete)
+ return &ChannelMonitorHistoryDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorHistoryClient) DeleteOne(_m *ChannelMonitorHistory) *ChannelMonitorHistoryDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorHistoryClient) DeleteOneID(id int64) *ChannelMonitorHistoryDeleteOne {
+ builder := c.Delete().Where(channelmonitorhistory.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorHistoryDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) Query() *ChannelMonitorHistoryQuery {
+ return &ChannelMonitorHistoryQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorHistory},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorHistory entity by its id.
+func (c *ChannelMonitorHistoryClient) Get(ctx context.Context, id int64) (*ChannelMonitorHistory, error) {
+ return c.Query().Where(channelmonitorhistory.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorHistoryClient) GetX(ctx context.Context, id int64) *ChannelMonitorHistory {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitor queries the monitor edge of a ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) QueryMonitor(_m *ChannelMonitorHistory) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorhistory.Table, channelmonitorhistory.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitorhistory.MonitorTable, channelmonitorhistory.MonitorColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorHistoryClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorHistory
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorHistoryClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorHistory
+}
+
+func (c *ChannelMonitorHistoryClient) mutate(ctx context.Context, m *ChannelMonitorHistoryMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorHistoryCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorHistoryUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorHistoryDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorHistory mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorRequestTemplateClient is a client for the ChannelMonitorRequestTemplate schema.
+type ChannelMonitorRequestTemplateClient struct {
+ config
+}
+
+// NewChannelMonitorRequestTemplateClient returns a client for the ChannelMonitorRequestTemplate from the given config.
+func NewChannelMonitorRequestTemplateClient(c config) *ChannelMonitorRequestTemplateClient {
+ return &ChannelMonitorRequestTemplateClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitorrequesttemplate.Hooks(f(g(h())))`.
+func (c *ChannelMonitorRequestTemplateClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorRequestTemplate = append(c.hooks.ChannelMonitorRequestTemplate, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitorrequesttemplate.Intercept(f(g(h())))`.
+func (c *ChannelMonitorRequestTemplateClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorRequestTemplate = append(c.inters.ChannelMonitorRequestTemplate, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorRequestTemplate entity.
+func (c *ChannelMonitorRequestTemplateClient) Create() *ChannelMonitorRequestTemplateCreate {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpCreate)
+ return &ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorRequestTemplate entities.
+func (c *ChannelMonitorRequestTemplateClient) CreateBulk(builders ...*ChannelMonitorRequestTemplateCreate) *ChannelMonitorRequestTemplateCreateBulk {
+ return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorRequestTemplateClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorRequestTemplateCreate, int)) *ChannelMonitorRequestTemplateCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorRequestTemplateCreateBulk{err: fmt.Errorf("calling to ChannelMonitorRequestTemplateClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorRequestTemplateCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Update() *ChannelMonitorRequestTemplateUpdate {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdate)
+ return &ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorRequestTemplateClient) UpdateOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplate(_m))
+ return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorRequestTemplateClient) UpdateOneID(id int64) *ChannelMonitorRequestTemplateUpdateOne {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplateID(id))
+ return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Delete() *ChannelMonitorRequestTemplateDelete {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpDelete)
+ return &ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorRequestTemplateClient) DeleteOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorRequestTemplateClient) DeleteOneID(id int64) *ChannelMonitorRequestTemplateDeleteOne {
+ builder := c.Delete().Where(channelmonitorrequesttemplate.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorRequestTemplateDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Query() *ChannelMonitorRequestTemplateQuery {
+ return &ChannelMonitorRequestTemplateQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorRequestTemplate},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorRequestTemplate entity by its id.
+func (c *ChannelMonitorRequestTemplateClient) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) {
+ return c.Query().Where(channelmonitorrequesttemplate.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorRequestTemplateClient) GetX(ctx context.Context, id int64) *ChannelMonitorRequestTemplate {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitors queries the monitors edge of a ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) QueryMonitors(_m *ChannelMonitorRequestTemplate) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorRequestTemplateClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorRequestTemplate
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorRequestTemplateClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorRequestTemplate
+}
+
+func (c *ChannelMonitorRequestTemplateClient) mutate(ctx context.Context, m *ChannelMonitorRequestTemplateMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorRequestTemplate mutation op: %q", m.Op())
+ }
+}
+
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@@ -1760,6 +2788,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco
}
}
+// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecisionClient struct {
+ config
+}
+
+// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config.
+func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient {
+ return &IdentityAdoptionDecisionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) {
+ c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...)
+}
+
+// Create returns a builder for creating a IdentityAdoptionDecision entity.
+func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate)
+ return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities.
+func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk {
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*IdentityAdoptionDecisionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate)
+ return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete)
+ return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne {
+ builder := c.Delete().Where(identityadoptiondecision.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &IdentityAdoptionDecisionDeleteOne{builder}
+}
+
+// Query returns a query builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery {
+ return &IdentityAdoptionDecisionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeIdentityAdoptionDecision},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a IdentityAdoptionDecision entity by its id.
+func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) {
+ return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryIdentity queries the identity edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *IdentityAdoptionDecisionClient) Hooks() []Hook {
+ return c.hooks.IdentityAdoptionDecision
+}
+
+// Interceptors returns the client interceptors.
+func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor {
+ return c.inters.IdentityAdoptionDecision
+}
+
+func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op())
+ }
+}
+
// PaymentAuditLogClient is a client for the PaymentAuditLog schema.
type PaymentAuditLogClient struct {
config
@@ -2175,6 +3368,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr
}
}
+// PendingAuthSessionClient is a client for the PendingAuthSession schema.
+type PendingAuthSessionClient struct {
+ config
+}
+
+// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config.
+func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient {
+ return &PendingAuthSessionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`.
+func (c *PendingAuthSessionClient) Use(hooks ...Hook) {
+ c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`.
+func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...)
+}
+
+// Create returns a builder for creating a PendingAuthSession entity.
+func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate {
+ mutation := newPendingAuthSessionMutation(c.config, OpCreate)
+ return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities.
+func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk {
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PendingAuthSessionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdate)
+ return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete {
+ mutation := newPendingAuthSessionMutation(c.config, OpDelete)
+ return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne {
+ builder := c.Delete().Where(pendingauthsession.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PendingAuthSessionDeleteOne{builder}
+}
+
+// Query returns a query builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery {
+ return &PendingAuthSessionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePendingAuthSession},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PendingAuthSession entity by its id.
+func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) {
+ return c.Query().Where(pendingauthsession.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryTargetUser queries the target_user edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *PendingAuthSessionClient) Hooks() []Hook {
+ return c.hooks.PendingAuthSession
+}
+
+// Interceptors returns the client interceptors.
+func (c *PendingAuthSessionClient) Interceptors() []Interceptor {
+ return c.inters.PendingAuthSession
+}
+
+func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op())
+ }
+}
+
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@@ -3951,6 +5309,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities queries the auth_identities edge of a User.
+func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User.
+func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -4628,20 +6018,24 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Hook
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
+ ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule,
+ Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder,
- PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Interceptor
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
+ ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule,
+ Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 96ed5e03..c9fcc314 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -17,12 +17,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -98,32 +106,40 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- paymentauditlog.Table: paymentauditlog.ValidColumn,
- paymentorder.Table: paymentorder.ValidColumn,
- paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- subscriptionplan.Table: subscriptionplan.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ channelmonitor.Table: channelmonitor.ValidColumn,
+ channelmonitordailyrollup.Table: channelmonitordailyrollup.ValidColumn,
+ channelmonitorhistory.Table: channelmonitorhistory.ValidColumn,
+ channelmonitorrequesttemplate.Table: channelmonitorrequesttemplate.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/group.go b/backend/ent/group.go
index f10b50c3..5d9ae2ed 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -79,6 +79,8 @@ type Group struct {
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
// OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
+ // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流
+ RpmLimit int `json:"rpm_limit,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
- case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
+ case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
values[i] = new(sql.NullString)
@@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
}
}
+ case group.FieldRpmLimit:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
+ } else if value.Valid {
+ _m.RpmLimit = int(value.Int64)
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -599,6 +607,9 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("messages_dispatch_model_config=")
builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
+ builder.WriteString(", ")
+ builder.WriteString("rpm_limit=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index b1371630..24bd9c13 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -76,6 +76,8 @@ const (
FieldDefaultMappedModel = "default_mapped_model"
// FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
+ // FieldRpmLimit holds the string denoting the rpm_limit field in the database.
+ FieldRpmLimit = "rpm_limit"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -181,6 +183,7 @@ var Columns = []string{
FieldRequirePrivacySet,
FieldDefaultMappedModel,
FieldMessagesDispatchModelConfig,
+ FieldRpmLimit,
}
var (
@@ -258,6 +261,8 @@ var (
DefaultMappedModelValidator func(string) error
// DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
+ // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
+ DefaultRpmLimit int
)
// OrderOption defines the ordering options for the Group queries.
@@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
}
+// ByRpmLimit orders the results by the rpm_limit field.
+func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index cba2ce5f..2814d130 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
+// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
+func RpmLimit(v int) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
}
+// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
+func RpmLimitEQ(v int) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
+func RpmLimitNEQ(v int) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitIn applies the In predicate on the "rpm_limit" field.
+func RpmLimitIn(vs ...int) predicate.Group {
+ return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
+func RpmLimitNotIn(vs ...int) predicate.Group {
+ return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
+func RpmLimitGT(v int) predicate.Group {
+ return predicate.Group(sql.FieldGT(FieldRpmLimit, v))
+}
+
+// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
+func RpmLimitGTE(v int) predicate.Group {
+ return predicate.Group(sql.FieldGTE(FieldRpmLimit, v))
+}
+
+// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
+func RpmLimitLT(v int) predicate.Group {
+ return predicate.Group(sql.FieldLT(FieldRpmLimit, v))
+}
+
+// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
+func RpmLimitLTE(v int) predicate.Group {
+ return predicate.Group(sql.FieldLTE(FieldRpmLimit, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index f412fa40..20ea0a0f 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
return _c
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate {
+ _c.mutation.SetRpmLimit(v)
+ return _c
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate {
+ if v != nil {
+ _c.SetRpmLimit(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultMessagesDispatchModelConfig
_c.mutation.SetMessagesDispatchModelConfig(v)
}
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ v := group.DefaultRpmLimit
+ _c.mutation.SetRpmLimit(v)
+ }
return nil
}
@@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
}
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)}
+ }
return nil
}
@@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
_node.MessagesDispatchModelConfig = value
}
+ if value, ok := _c.mutation.RpmLimit(); ok {
+ _spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
+ _node.RpmLimit = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
return u
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert {
+ u.Set(group.FieldRpmLimit, v)
+ return u
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert {
+ u.SetExcluded(group.FieldRpmLimit)
+ return u
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert {
+ u.Add(group.FieldRpmLimit, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
})
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateRpmLimit()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
})
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateRpmLimit()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index 7b6d6193..cc14f897 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe
return _u
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate {
+ _u.mutation.AddRpmLimit(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA
return _u
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne {
+ _u.mutation.AddRpmLimit(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
_spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
}
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index 199dacea..414eba24 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -69,6 +69,78 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentity mutator.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentityChannel mutator.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m)
+}
+
+// The ChannelMonitorFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitor mutator.
+type ChannelMonitorFunc func(context.Context, *ent.ChannelMonitorMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorMutation", m)
+}
+
+// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorDailyRollup mutator.
+type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorDailyRollupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorDailyRollupMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorDailyRollupMutation", m)
+}
+
+// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorHistory mutator.
+type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorHistoryFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorHistoryMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorHistoryMutation", m)
+}
+
+// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorRequestTemplate mutator.
+type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorRequestTemplateFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorRequestTemplateMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorRequestTemplateMutation", m)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
@@ -105,6 +177,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary
+// function as IdentityAdoptionDecision mutator.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary
// function as PaymentAuditLog mutator.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error)
@@ -141,6 +225,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary
+// function as PendingAuthSession mutator.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PendingAuthSessionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go
new file mode 100644
index 00000000..ecaee65c
--- /dev/null
+++ b/backend/ent/identityadoptiondecision.go
@@ -0,0 +1,223 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecision struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // PendingAuthSessionID holds the value of the "pending_auth_session_id" field.
+ PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID *int64 `json:"identity_id,omitempty"`
+ // AdoptDisplayName holds the value of the "adopt_display_name" field.
+ AdoptDisplayName bool `json:"adopt_display_name,omitempty"`
+ // AdoptAvatar holds the value of the "adopt_avatar" field.
+ AdoptAvatar bool `json:"adopt_avatar,omitempty"`
+ // DecidedAt holds the value of the "decided_at" field.
+ DecidedAt time.Time `json:"decided_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set.
+ Edges IdentityAdoptionDecisionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph.
+type IdentityAdoptionDecisionEdges struct {
+ // PendingAuthSession holds the value of the pending_auth_session edge.
+ PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"`
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) {
+ if e.PendingAuthSession != nil {
+ return e.PendingAuthSession, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: pendingauthsession.Label}
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_session"}
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar:
+ values[i] = new(sql.NullBool)
+ case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the IdentityAdoptionDecision fields.
+func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case identityadoptiondecision.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i])
+ } else if value.Valid {
+ _m.PendingAuthSessionID = value.Int64
+ }
+ case identityadoptiondecision.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = new(int64)
+ *_m.IdentityID = value.Int64
+ }
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i])
+ } else if value.Valid {
+ _m.AdoptDisplayName = value.Bool
+ }
+ case identityadoptiondecision.FieldAdoptAvatar:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i])
+ } else if value.Valid {
+ _m.AdoptAvatar = value.Bool
+ }
+ case identityadoptiondecision.FieldDecidedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field decided_at", values[i])
+ } else if value.Valid {
+ _m.DecidedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision.
+// This includes values selected through modifiers, order, etc.
+func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m)
+}
+
+// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this IdentityAdoptionDecision.
+// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne {
+ return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: IdentityAdoptionDecision is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *IdentityAdoptionDecision) String() string {
+ var builder strings.Builder
+ builder.WriteString("IdentityAdoptionDecision(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("pending_auth_session_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID))
+ builder.WriteString(", ")
+ if v := _m.IdentityID; v != nil {
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("adopt_display_name=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName))
+ builder.WriteString(", ")
+ builder.WriteString("adopt_avatar=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar))
+ builder.WriteString(", ")
+ builder.WriteString("decided_at=")
+ builder.WriteString(_m.DecidedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision.
+type IdentityAdoptionDecisions []*IdentityAdoptionDecision
diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
new file mode 100644
index 00000000..93adaf73
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
@@ -0,0 +1,159 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the identityadoptiondecision type in the database.
+ Label = "identity_adoption_decision"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database.
+ FieldPendingAuthSessionID = "pending_auth_session_id"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database.
+ FieldAdoptDisplayName = "adopt_display_name"
+ // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database.
+ FieldAdoptAvatar = "adopt_avatar"
+ // FieldDecidedAt holds the string denoting the decided_at field in the database.
+ FieldDecidedAt = "decided_at"
+ // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations.
+ EdgePendingAuthSession = "pending_auth_session"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the identityadoptiondecision in the database.
+ Table = "identity_adoption_decisions"
+ // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge.
+ PendingAuthSessionTable = "identity_adoption_decisions"
+ // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge.
+ PendingAuthSessionColumn = "pending_auth_session_id"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "identity_adoption_decisions"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for identityadoptiondecision fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldPendingAuthSessionID,
+ FieldIdentityID,
+ FieldAdoptDisplayName,
+ FieldAdoptAvatar,
+ FieldDecidedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field.
+ DefaultAdoptDisplayName bool
+ // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field.
+ DefaultAdoptAvatar bool
+ // DefaultDecidedAt holds the default value on creation for the "decided_at" field.
+ DefaultDecidedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the IdentityAdoptionDecision queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionID orders the results by the pending_auth_session_id field.
+func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByAdoptDisplayName orders the results by the adopt_display_name field.
+func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc()
+}
+
+// ByAdoptAvatar orders the results by the adopt_avatar field.
+func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc()
+}
+
+// ByDecidedAt orders the results by the decided_at field.
+func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDecidedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionField orders the results by pending_auth_session field.
+func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newPendingAuthSessionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go
new file mode 100644
index 00000000..1968f175
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/where.go
@@ -0,0 +1,342 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ.
+func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ.
+func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ.
+func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ.
+func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...))
+}
+
+// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field.
+func IdentityIDIsNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID))
+}
+
+// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field.
+func IdentityIDNotNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID))
+}
+
+// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field.
+func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field.
+func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAtEQ applies the EQ predicate on the "decided_at" field.
+func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field.
+func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtIn applies the In predicate on the "decided_at" field.
+func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field.
+func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtGT applies the GT predicate on the "decided_at" field.
+func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v))
+}
+
+// DecidedAtGTE applies the GTE predicate on the "decided_at" field.
+func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v))
+}
+
+// DecidedAtLT applies the LT predicate on the "decided_at" field.
+func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v))
+}
+
+// DecidedAtLTE applies the LTE predicate on the "decided_at" field.
+func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v))
+}
+
+// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge.
+func HasPendingAuthSession() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates).
+func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newPendingAuthSessionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.NotPredicates(p))
+}
diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go
new file mode 100644
index 00000000..491ba9f9
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_create.go
@@ -0,0 +1,843 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionCreate struct {
+ config
+ mutation *IdentityAdoptionDecisionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetPendingAuthSessionID(v)
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetIdentityID(*v)
+ }
+ return _c
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptDisplayName(v)
+ return _c
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptDisplayName(*v)
+ }
+ return _c
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptAvatar(v)
+ return _c
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptAvatar(*v)
+ }
+ return _c
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetDecidedAt(v)
+ return _c
+}
+
+// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetDecidedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate {
+ return _c.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _c.mutation
+}
+
+// Save creates the IdentityAdoptionDecision in the database.
+func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *IdentityAdoptionDecisionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := identityadoptiondecision.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ v := identityadoptiondecision.DefaultAdoptDisplayName
+ _c.mutation.SetAdoptDisplayName(v)
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ v := identityadoptiondecision.DefaultAdoptAvatar
+ _c.mutation.SetAdoptAvatar(v)
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ v := identityadoptiondecision.DefaultDecidedAt()
+ _c.mutation.SetDecidedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *IdentityAdoptionDecisionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)}
+ }
+ if _, ok := _c.mutation.PendingAuthSessionID(); !ok {
+ return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)}
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)}
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)}
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)}
+ }
+ if len(_c.mutation.PendingAuthSessionIDs()) == 0 {
+ return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)}
+ }
+ return nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) {
+ var (
+ _node = &IdentityAdoptionDecision{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ _node.AdoptDisplayName = value
+ }
+ if value, ok := _c.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ _node.AdoptAvatar = value
+ }
+ if value, ok := _c.mutation.DecidedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value)
+ _node.DecidedAt = value
+ }
+ if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.PendingAuthSessionID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing
+ // one IdentityAdoptionDecision node.
+ IdentityAdoptionDecisionUpsertOne struct {
+ create *IdentityAdoptionDecisionCreate
+ }
+
+ // IdentityAdoptionDecisionUpsert is the "OnConflict" setter.
+ IdentityAdoptionDecisionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldUpdatedAt)
+ return u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v)
+ return u
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetNull(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptDisplayName, v)
+ return u
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName)
+ return u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptAvatar, v)
+ return u
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := u.create.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk.
+type IdentityAdoptionDecisionCreateBulk struct {
+ config
+ err error
+ builders []*IdentityAdoptionDecisionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the IdentityAdoptionDecision entities in the database.
+func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*IdentityAdoptionDecision, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*IdentityAdoptionDecisionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing
+// a bulk of IdentityAdoptionDecision nodes.
+type IdentityAdoptionDecisionUpsertBulk struct {
+ create *IdentityAdoptionDecisionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := b.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go
new file mode 100644
index 00000000..ef3d328d
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDelete struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDeleteOne struct {
+ _d *IdentityAdoptionDecisionDelete
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go
new file mode 100644
index 00000000..4082d8ee
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_query.go
@@ -0,0 +1,721 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionQuery struct {
+ config
+ ctx *QueryContext
+ order []identityadoptiondecision.OrderOption
+ inters []Interceptor
+ predicates []predicate.IdentityAdoptionDecision
+ withPendingAuthSession *PendingAuthSessionQuery
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder.
+func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first IdentityAdoptionDecision entity from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision was found.
+func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first IdentityAdoptionDecision ID from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found.
+func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found.
+// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found.
+func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil, &NotSingularError{identityadoptiondecision.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{identityadoptiondecision.Label}
+ default:
+ err = &NotSingularError{identityadoptiondecision.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of IdentityAdoptionDecisions.
+func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]()
+ return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of IdentityAdoptionDecision IDs.
+func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &IdentityAdoptionDecisionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]identityadoptiondecision.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...),
+ withPendingAuthSession: _q.withPendingAuthSession.Clone(),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSession = query
+ return _q
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// GroupBy(identityadoptiondecision.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &IdentityAdoptionDecisionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = identityadoptiondecision.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// Select(identityadoptiondecision.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q}
+ sbuild.label = identityadoptiondecision.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations.
+func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) {
+ var (
+ nodes = []*IdentityAdoptionDecision{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withPendingAuthSession != nil,
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*IdentityAdoptionDecision).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &IdentityAdoptionDecision{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withPendingAuthSession; query != nil {
+ if err := _q.loadPendingAuthSession(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ fk := nodes[i].PendingAuthSessionID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(pendingauthsession.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ if nodes[i].IdentityID == nil {
+ continue
+ }
+ fk := *nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for i := range fields {
+ if fields[i] != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withPendingAuthSession != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(identityadoptiondecision.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = identityadoptiondecision.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionGroupBy struct {
+ selector
+ build *IdentityAdoptionDecisionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionSelect struct {
+ *IdentityAdoptionDecisionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v)
+}
+
+func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go
new file mode 100644
index 00000000..0ca21d27
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_update.go
@@ -0,0 +1,532 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionUpdate struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdate) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated IdentityAdoptionDecision entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for _, f := range fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &IdentityAdoptionDecision{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 8d8320bb..95b68e09 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -13,12 +13,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -228,6 +236,168 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The ChannelMonitorFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorFunc func(context.Context, *ent.ChannelMonitorQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q)
+}
+
+// The TraverseChannelMonitor type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitor func(context.Context, *ent.ChannelMonitorQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitor) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitor) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q)
+}
+
+// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorDailyRollupFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q)
+}
+
+// The TraverseChannelMonitorDailyRollup type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorDailyRollup func(context.Context, *ent.ChannelMonitorDailyRollupQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorDailyRollup) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorDailyRollup) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q)
+}
+
+// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorHistoryFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorHistoryQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorHistoryQuery", q)
+}
+
+// The TraverseChannelMonitorHistory type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorHistory func(context.Context, *ent.ChannelMonitorHistoryQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorHistory) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorHistory) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorHistoryQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorHistoryQuery", q)
+}
+
+// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorRequestTemplateFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q)
+}
+
+// The TraverseChannelMonitorRequestTemplate type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorRequestTemplate func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorRequestTemplate) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorRequestTemplate) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@@ -309,6 +479,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
+// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error)
@@ -390,6 +587,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que
return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
}
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
+// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@@ -808,18 +1032,34 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
+ case *ent.AuthIdentityQuery:
+ return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil
+ case *ent.AuthIdentityChannelQuery:
+ return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
+ case *ent.ChannelMonitorQuery:
+ return &query[*ent.ChannelMonitorQuery, predicate.ChannelMonitor, channelmonitor.OrderOption]{typ: ent.TypeChannelMonitor, tq: q}, nil
+ case *ent.ChannelMonitorDailyRollupQuery:
+ return &query[*ent.ChannelMonitorDailyRollupQuery, predicate.ChannelMonitorDailyRollup, channelmonitordailyrollup.OrderOption]{typ: ent.TypeChannelMonitorDailyRollup, tq: q}, nil
+ case *ent.ChannelMonitorHistoryQuery:
+ return &query[*ent.ChannelMonitorHistoryQuery, predicate.ChannelMonitorHistory, channelmonitorhistory.OrderOption]{typ: ent.TypeChannelMonitorHistory, tq: q}, nil
+ case *ent.ChannelMonitorRequestTemplateQuery:
+ return &query[*ent.ChannelMonitorRequestTemplateQuery, predicate.ChannelMonitorRequestTemplate, channelmonitorrequesttemplate.OrderOption]{typ: ent.TypeChannelMonitorRequestTemplate, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.IdempotencyRecordQuery:
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
+ case *ent.IdentityAdoptionDecisionQuery:
+ return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil
case *ent.PaymentAuditLogQuery:
return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil
case *ent.PaymentOrderQuery:
return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil
case *ent.PaymentProviderInstanceQuery:
return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil
+ case *ent.PendingAuthSessionQuery:
+ return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:
diff --git a/backend/ent/migrate/auth_identity_fk_ondelete_test.go b/backend/ent/migrate/auth_identity_fk_ondelete_test.go
new file mode 100644
index 00000000..0e37025a
--- /dev/null
+++ b/backend/ent/migrate/auth_identity_fk_ondelete_test.go
@@ -0,0 +1,73 @@
+package migrate
+
+import (
+ "testing"
+
+ "entgo.io/ent/dialect/entsql"
+ entschema "entgo.io/ent/dialect/sql/schema"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) {
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete,
+ )
+
+ require.Equal(
+ t,
+ entschema.SetNull,
+ findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.SetNull,
+ findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete,
+ )
+}
+
+func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) {
+ idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no")
+ require.True(t, idx.Unique)
+ require.Len(t, idx.Columns, 1)
+ require.Equal(t, "out_trade_no", idx.Columns[0].Name)
+ require.NotNil(t, idx.Annotation)
+ require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where)
+}
+
+func findForeignKeyBySymbol(t *testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey {
+ t.Helper()
+
+ for _, fk := range table.ForeignKeys {
+ if fk.Symbol == symbol {
+ return fk
+ }
+ }
+
+ require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol)
+ return nil
+}
+
+func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index {
+ t.Helper()
+
+ for _, idx := range table.Indexes {
+ if idx.Name == name {
+ return idx
+ }
+ }
+
+ require.Failf(t, "missing index", "table %s should include index %s", table.Name, name)
+ return nil
+}
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 68bdbf55..178ae170 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -338,6 +338,252 @@ var (
},
},
}
+ // AuthIdentitiesColumns holds the columns for the "auth_identities" table.
+ AuthIdentitiesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // AuthIdentitiesTable holds the schema information for the "auth_identities" table.
+ AuthIdentitiesTable = &schema.Table{
+ Name: "auth_identities",
+ Columns: AuthIdentitiesColumns,
+ PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identities_users_auth_identities",
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentity_provider_type_provider_key_provider_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]},
+ },
+ {
+ Name: "authidentity_user_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ },
+ {
+ Name: "authidentity_user_id_provider_type",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]},
+ },
+ },
+ }
+ // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table.
+ AuthIdentityChannelsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel", Type: field.TypeString, Size: 20},
+ {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "identity_id", Type: field.TypeInt64},
+ }
+ // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table.
+ AuthIdentityChannelsTable = &schema.Table{
+ Name: "auth_identity_channels",
+ Columns: AuthIdentityChannelsColumns,
+ PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identity_channels_auth_identities_channels",
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]},
+ },
+ {
+ Name: "authidentitychannel_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ },
+ },
+ }
+ // ChannelMonitorsColumns holds the columns for the "channel_monitors" table.
+ ChannelMonitorsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "name", Type: field.TypeString, Size: 100},
+ {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "endpoint", Type: field.TypeString, Size: 500},
+ {Name: "api_key_encrypted", Type: field.TypeString},
+ {Name: "primary_model", Type: field.TypeString, Size: 200},
+ {Name: "extra_models", Type: field.TypeJSON},
+ {Name: "group_name", Type: field.TypeString, Nullable: true, Size: 100, Default: ""},
+ {Name: "enabled", Type: field.TypeBool, Default: true},
+ {Name: "interval_seconds", Type: field.TypeInt},
+ {Name: "last_checked_at", Type: field.TypeTime, Nullable: true},
+ {Name: "created_by", Type: field.TypeInt64},
+ {Name: "extra_headers", Type: field.TypeJSON},
+ {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
+ {Name: "body_override", Type: field.TypeJSON, Nullable: true},
+ {Name: "template_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // ChannelMonitorsTable holds the schema information for the "channel_monitors" table.
+ ChannelMonitorsTable = &schema.Table{
+ Name: "channel_monitors",
+ Columns: ChannelMonitorsColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitors_channel_monitor_request_templates_request_template",
+ Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitor_enabled_last_checked_at",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[10], ChannelMonitorsColumns[12]},
+ },
+ {
+ Name: "channelmonitor_provider",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[4]},
+ },
+ {
+ Name: "channelmonitor_group_name",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[9]},
+ },
+ {
+ Name: "channelmonitor_template_id",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ },
+ },
+ }
+ // ChannelMonitorDailyRollupsColumns holds the columns for the "channel_monitor_daily_rollups" table.
+ ChannelMonitorDailyRollupsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "model", Type: field.TypeString, Size: 200},
+ {Name: "bucket_date", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "date"}},
+ {Name: "total_checks", Type: field.TypeInt, Default: 0},
+ {Name: "ok_count", Type: field.TypeInt, Default: 0},
+ {Name: "operational_count", Type: field.TypeInt, Default: 0},
+ {Name: "degraded_count", Type: field.TypeInt, Default: 0},
+ {Name: "failed_count", Type: field.TypeInt, Default: 0},
+ {Name: "error_count", Type: field.TypeInt, Default: 0},
+ {Name: "sum_latency_ms", Type: field.TypeInt64, Default: 0},
+ {Name: "count_latency", Type: field.TypeInt, Default: 0},
+ {Name: "sum_ping_latency_ms", Type: field.TypeInt64, Default: 0},
+ {Name: "count_ping_latency", Type: field.TypeInt, Default: 0},
+ {Name: "computed_at", Type: field.TypeTime},
+ {Name: "monitor_id", Type: field.TypeInt64},
+ }
+ // ChannelMonitorDailyRollupsTable holds the schema information for the "channel_monitor_daily_rollups" table.
+ ChannelMonitorDailyRollupsTable = &schema.Table{
+ Name: "channel_monitor_daily_rollups",
+ Columns: ChannelMonitorDailyRollupsColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorDailyRollupsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitor_daily_rollups_channel_monitors_daily_rollups",
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14]},
+ RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitordailyrollup_monitor_id_model_bucket_date",
+ Unique: true,
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14], ChannelMonitorDailyRollupsColumns[1], ChannelMonitorDailyRollupsColumns[2]},
+ },
+ {
+ Name: "channelmonitordailyrollup_bucket_date",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[2]},
+ },
+ },
+ }
+ // ChannelMonitorHistoriesColumns holds the columns for the "channel_monitor_histories" table.
+ ChannelMonitorHistoriesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "model", Type: field.TypeString, Size: 200},
+ {Name: "status", Type: field.TypeEnum, Enums: []string{"operational", "degraded", "failed", "error"}},
+ {Name: "latency_ms", Type: field.TypeInt, Nullable: true},
+ {Name: "ping_latency_ms", Type: field.TypeInt, Nullable: true},
+ {Name: "message", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
+ {Name: "checked_at", Type: field.TypeTime},
+ {Name: "monitor_id", Type: field.TypeInt64},
+ }
+ // ChannelMonitorHistoriesTable holds the schema information for the "channel_monitor_histories" table.
+ ChannelMonitorHistoriesTable = &schema.Table{
+ Name: "channel_monitor_histories",
+ Columns: ChannelMonitorHistoriesColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorHistoriesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitor_histories_channel_monitors_history",
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]},
+ RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitorhistory_monitor_id_model_checked_at",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7], ChannelMonitorHistoriesColumns[1], ChannelMonitorHistoriesColumns[6]},
+ },
+ {
+ Name: "channelmonitorhistory_checked_at",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[6]},
+ },
+ },
+ }
+ // ChannelMonitorRequestTemplatesColumns holds the columns for the "channel_monitor_request_templates" table.
+ ChannelMonitorRequestTemplatesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "name", Type: field.TypeString, Size: 100},
+ {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
+ {Name: "extra_headers", Type: field.TypeJSON},
+ {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
+ {Name: "body_override", Type: field.TypeJSON, Nullable: true},
+ }
+ // ChannelMonitorRequestTemplatesTable holds the schema information for the "channel_monitor_request_templates" table.
+ ChannelMonitorRequestTemplatesTable = &schema.Table{
+ Name: "channel_monitor_request_templates",
+ Columns: ChannelMonitorRequestTemplatesColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitorrequesttemplate_provider_name",
+ Unique: true,
+ Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]},
+ },
+ },
+ }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -408,6 +654,7 @@ var (
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
{Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "rpm_limit", Type: field.TypeInt, Default: 0},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -485,6 +732,49 @@ var (
},
},
}
+ // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "adopt_display_name", Type: field.TypeBool, Default: false},
+ {Name: "adopt_avatar", Type: field.TypeBool, Default: false},
+ {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "identity_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true},
+ }
+ // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsTable = &schema.Table{
+ Name: "identity_adoption_decisions",
+ Columns: IdentityAdoptionDecisionsColumns,
+ PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ {
+ Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "identityadoptiondecision_pending_auth_session_id",
+ Unique: true,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ },
+ {
+ Name: "identityadoptiondecision_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ },
+ },
+ }
// PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table.
PaymentAuditLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -528,6 +818,8 @@ var (
{Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "subscription_days", Type: field.TypeInt, Nullable: true},
{Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64},
+ {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30},
+ {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"},
{Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
{Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
@@ -556,7 +848,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "payment_orders_users_payment_orders",
- Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -564,38 +856,41 @@ var (
Indexes: []*schema.Index{
{
Name: "paymentorder_out_trade_no",
- Unique: false,
+ Unique: true,
Columns: []*schema.Column{PaymentOrdersColumns[8]},
+ Annotation: &entsql.IndexAnnotation{
+ Where: "out_trade_no <> ''",
+ },
},
{
Name: "paymentorder_user_id",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
},
{
Name: "paymentorder_status",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[19]},
+ Columns: []*schema.Column{PaymentOrdersColumns[21]},
},
{
Name: "paymentorder_expires_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[27]},
+ Columns: []*schema.Column{PaymentOrdersColumns[29]},
},
{
Name: "paymentorder_created_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[35]},
+ Columns: []*schema.Column{PaymentOrdersColumns[37]},
},
{
Name: "paymentorder_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_payment_type_paid_at",
Unique: false,
- Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]},
+ Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]},
},
{
Name: "paymentorder_order_type",
@@ -638,6 +933,72 @@ var (
},
},
}
+ // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table.
+ PendingAuthSessionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "session_token", Type: field.TypeString, Size: 255},
+ {Name: "intent", Type: field.TypeString, Size: 40},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "target_user_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table.
+ PendingAuthSessionsTable = &schema.Table{
+ Name: "pending_auth_sessions",
+ Columns: PendingAuthSessionsColumns,
+ PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "pending_auth_sessions_users_pending_auth_sessions",
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "pendingauthsession_session_token",
+ Unique: true,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[3]},
+ },
+ {
+ Name: "pendingauthsession_target_user_id",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ },
+ {
+ Name: "pendingauthsession_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[19]},
+ },
+ {
+ Name: "pendingauthsession_provider_type_provider_key_provider_subject",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]},
+ },
+ {
+ Name: "pendingauthsession_completion_code_hash",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[14]},
+ },
+ },
+ }
// PromoCodesColumns holds the columns for the "promo_codes" table.
PromoCodesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -1079,11 +1440,15 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "signup_source", Type: field.TypeString, Default: "email"},
+ {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
+ {Name: "rpm_limit", Type: field.TypeInt, Default: 0},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
@@ -1318,12 +1683,20 @@ var (
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
+ AuthIdentitiesTable,
+ AuthIdentityChannelsTable,
+ ChannelMonitorsTable,
+ ChannelMonitorDailyRollupsTable,
+ ChannelMonitorHistoriesTable,
+ ChannelMonitorRequestTemplatesTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
+ IdentityAdoptionDecisionsTable,
PaymentAuditLogsTable,
PaymentOrdersTable,
PaymentProviderInstancesTable,
+ PendingAuthSessionsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
@@ -1365,6 +1738,29 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
+ AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable
+ AuthIdentitiesTable.Annotation = &entsql.Annotation{
+ Table: "auth_identities",
+ }
+ AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
+ Table: "auth_identity_channels",
+ }
+ ChannelMonitorsTable.ForeignKeys[0].RefTable = ChannelMonitorRequestTemplatesTable
+ ChannelMonitorsTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitors",
+ }
+ ChannelMonitorDailyRollupsTable.ForeignKeys[0].RefTable = ChannelMonitorsTable
+ ChannelMonitorDailyRollupsTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_daily_rollups",
+ }
+ ChannelMonitorHistoriesTable.ForeignKeys[0].RefTable = ChannelMonitorsTable
+ ChannelMonitorHistoriesTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_histories",
+ }
+ ChannelMonitorRequestTemplatesTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_request_templates",
+ }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
@@ -1374,6 +1770,11 @@ func init() {
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
Table: "idempotency_records",
}
+ IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable
+ IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{
+ Table: "identity_adoption_decisions",
+ }
PaymentAuditLogsTable.Annotation = &entsql.Annotation{
Table: "payment_audit_logs",
}
@@ -1384,6 +1785,10 @@ func init() {
PaymentProviderInstancesTable.Annotation = &entsql.Annotation{
Table: "payment_provider_instances",
}
+ PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable
+ PendingAuthSessionsTable.Annotation = &entsql.Annotation{
+ Table: "pending_auth_sessions",
+ }
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 524ccb92..d616e4ae 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -17,12 +17,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -51,32 +59,40 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypePaymentAuditLog = "PaymentAuditLog"
- TypePaymentOrder = "PaymentOrder"
- TypePaymentProviderInstance = "PaymentProviderInstance"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeSubscriptionPlan = "SubscriptionPlan"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeChannelMonitor = "ChannelMonitor"
+ TypeChannelMonitorDailyRollup = "ChannelMonitorDailyRollup"
+ TypeChannelMonitorHistory = "ChannelMonitorHistory"
+ TypeChannelMonitorRequestTemplate = "ChannelMonitorRequestTemplate"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -6887,6 +6903,6522 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AnnouncementRead edge %s", name)
}
+// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph.
+type AuthIdentityMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ verified_at *time.Time
+ issuer *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ channels map[int64]struct{}
+ removedchannels map[int64]struct{}
+ clearedchannels bool
+ adoption_decisions map[int64]struct{}
+ removedadoption_decisions map[int64]struct{}
+ clearedadoption_decisions bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentity, error)
+ predicates []predicate.AuthIdentity
+}
+
+var _ ent.Mutation = (*AuthIdentityMutation)(nil)
+
+// authidentityOption allows management of the mutation configuration using functional options.
+type authidentityOption func(*AuthIdentityMutation)
+
+// newAuthIdentityMutation creates new mutation for the AuthIdentity entity.
+func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation {
+ m := &AuthIdentityMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentity,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityID sets the ID field of the mutation.
+func withAuthIdentityID(id int64) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentity
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentity, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentity.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentity sets the old AuthIdentity of the mutation.
+func withAuthIdentity(node *AuthIdentity) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentity, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetUserID sets the "user_id" field.
+func (m *AuthIdentityMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *AuthIdentityMutation) UserID() (r int64, exists bool) {
+ v := m.user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserID returns the old "user_id" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ }
+ return oldValue.UserID, nil
+}
+
+// ResetUserID resets all changes to the "user_id" field.
+func (m *AuthIdentityMutation) ResetUserID() {
+ m.user = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *AuthIdentityMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *AuthIdentityMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) {
+ m.verified_at = &t
+}
+
+// VerifiedAt returns the value of the "verified_at" field in the mutation.
+func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) {
+ v := m.verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err)
+ }
+ return oldValue.VerifiedAt, nil
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (m *AuthIdentityMutation) ClearVerifiedAt() {
+ m.verified_at = nil
+ m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{}
+}
+
+// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation.
+func (m *AuthIdentityMutation) VerifiedAtCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldVerifiedAt]
+ return ok
+}
+
+// ResetVerifiedAt resets all changes to the "verified_at" field.
+func (m *AuthIdentityMutation) ResetVerifiedAt() {
+ m.verified_at = nil
+ delete(m.clearedFields, authidentity.FieldVerifiedAt)
+}
+
+// SetIssuer sets the "issuer" field.
+func (m *AuthIdentityMutation) SetIssuer(s string) {
+ m.issuer = &s
+}
+
+// Issuer returns the value of the "issuer" field in the mutation.
+func (m *AuthIdentityMutation) Issuer() (r string, exists bool) {
+ v := m.issuer
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIssuer is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIssuer requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIssuer: %w", err)
+ }
+ return oldValue.Issuer, nil
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (m *AuthIdentityMutation) ClearIssuer() {
+ m.issuer = nil
+ m.clearedFields[authidentity.FieldIssuer] = struct{}{}
+}
+
+// IssuerCleared returns if the "issuer" field was cleared in this mutation.
+func (m *AuthIdentityMutation) IssuerCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldIssuer]
+ return ok
+}
+
+// ResetIssuer resets all changes to the "issuer" field.
+func (m *AuthIdentityMutation) ResetIssuer() {
+ m.issuer = nil
+ delete(m.clearedFields, authidentity.FieldIssuer)
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (m *AuthIdentityMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[authidentity.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *AuthIdentityMutation) UserCleared() bool {
+ return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *AuthIdentityMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids.
+func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) {
+ if m.channels == nil {
+ m.channels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.channels[ids[i]] = struct{}{}
+ }
+}
+
+// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) ClearChannels() {
+ m.clearedchannels = true
+}
+
+// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared.
+func (m *AuthIdentityMutation) ChannelsCleared() bool {
+ return m.clearedchannels
+}
+
+// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) {
+ if m.removedchannels == nil {
+ m.removedchannels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.channels, ids[i])
+ m.removedchannels[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) {
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ChannelsIDs returns the "channels" edge IDs in the mutation.
+func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) {
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetChannels resets all changes to the "channels" edge.
+func (m *AuthIdentityMutation) ResetChannels() {
+ m.channels = nil
+ m.clearedchannels = false
+ m.removedchannels = nil
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids.
+func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) {
+ if m.adoption_decisions == nil {
+ m.adoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.adoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) ClearAdoptionDecisions() {
+ m.clearedadoption_decisions = true
+}
+
+// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool {
+ return m.clearedadoption_decisions
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) {
+ if m.removedadoption_decisions == nil {
+ m.removedadoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.adoption_decisions, ids[i])
+ m.removedadoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation.
+func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge.
+func (m *AuthIdentityMutation) ResetAdoptionDecisions() {
+ m.adoption_decisions = nil
+ m.clearedadoption_decisions = false
+ m.removedadoption_decisions = nil
+}
+
+// Where appends a list predicates to the AuthIdentityMutation builder.
+func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentity, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentity).
+func (m *AuthIdentityMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentity.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentity.FieldUpdatedAt)
+ }
+ if m.user != nil {
+ fields = append(fields, authidentity.FieldUserID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentity.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentity.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, authidentity.FieldProviderSubject)
+ }
+ if m.verified_at != nil {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.issuer != nil {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentity.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentity.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentity.FieldUserID:
+ return m.UserID()
+ case authidentity.FieldProviderType:
+ return m.ProviderType()
+ case authidentity.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentity.FieldProviderSubject:
+ return m.ProviderSubject()
+ case authidentity.FieldVerifiedAt:
+ return m.VerifiedAt()
+ case authidentity.FieldIssuer:
+ return m.Issuer()
+ case authidentity.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentity.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentity.FieldUserID:
+ return m.OldUserID(ctx)
+ case authidentity.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentity.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentity.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case authidentity.FieldVerifiedAt:
+ return m.OldVerifiedAt(ctx)
+ case authidentity.FieldIssuer:
+ return m.OldIssuer(ctx)
+ case authidentity.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentity.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentity.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case authidentity.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentity.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentity.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case authidentity.FieldVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetVerifiedAt(v)
+ return nil
+ case authidentity.FieldIssuer:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIssuer(v)
+ return nil
+ case authidentity.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentity numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(authidentity.FieldVerifiedAt) {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.FieldCleared(authidentity.FieldIssuer) {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ClearField(name string) error {
+ switch name {
+ case authidentity.FieldVerifiedAt:
+ m.ClearVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ClearIssuer()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ResetField(name string) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentity.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentity.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case authidentity.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentity.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentity.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case authidentity.FieldVerifiedAt:
+ m.ResetVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ResetIssuer()
+ return nil
+ case authidentity.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.user != nil {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.channels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.adoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.channels))
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.adoption_decisions))
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedchannels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.removedadoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.removedchannels))
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.removedadoption_decisions))
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.cleareduser {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.clearedchannels {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.clearedadoption_decisions {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentity.EdgeUser:
+ return m.cleareduser
+ case authidentity.EdgeChannels:
+ return m.clearedchannels
+ case authidentity.EdgeAdoptionDecisions:
+ return m.clearedadoption_decisions
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ResetUser()
+ return nil
+ case authidentity.EdgeChannels:
+ m.ResetChannels()
+ return nil
+ case authidentity.EdgeAdoptionDecisions:
+ m.ResetAdoptionDecisions()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity edge %s", name)
+}
+
+// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph.
+type AuthIdentityChannelMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ channel *string
+ channel_app_id *string
+ channel_subject *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentityChannel, error)
+ predicates []predicate.AuthIdentityChannel
+}
+
+var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil)
+
+// authidentitychannelOption allows management of the mutation configuration using functional options.
+type authidentitychannelOption func(*AuthIdentityChannelMutation)
+
+// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity.
+func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation {
+ m := &AuthIdentityChannelMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentityChannel,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityChannelID sets the ID field of the mutation.
+func withAuthIdentityChannelID(id int64) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentityChannel
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentityChannel.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation.
+func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentityChannel, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityChannelMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityChannelMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityChannelMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityChannelMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *AuthIdentityChannelMutation) ResetIdentityID() {
+ m.identity = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityChannelMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityChannelMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityChannelMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityChannelMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetChannel sets the "channel" field.
+func (m *AuthIdentityChannelMutation) SetChannel(s string) {
+ m.channel = &s
+}
+
+// Channel returns the value of the "channel" field in the mutation.
+func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) {
+ v := m.channel
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannel: %w", err)
+ }
+ return oldValue.Channel, nil
+}
+
+// ResetChannel resets all changes to the "channel" field.
+func (m *AuthIdentityChannelMutation) ResetChannel() {
+ m.channel = nil
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) {
+ m.channel_app_id = &s
+}
+
+// ChannelAppID returns the value of the "channel_app_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) {
+ v := m.channel_app_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelAppID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err)
+ }
+ return oldValue.ChannelAppID, nil
+}
+
+// ResetChannelAppID resets all changes to the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) ResetChannelAppID() {
+ m.channel_app_id = nil
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) {
+ m.channel_subject = &s
+}
+
+// ChannelSubject returns the value of the "channel_subject" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) {
+ v := m.channel_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err)
+ }
+ return oldValue.ChannelSubject, nil
+}
+
+// ResetChannelSubject resets all changes to the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) ResetChannelSubject() {
+ m.channel_subject = nil
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityChannelMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *AuthIdentityChannelMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *AuthIdentityChannelMutation) IdentityCleared() bool {
+ return m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *AuthIdentityChannelMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the AuthIdentityChannelMutation builder.
+func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentityChannel, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityChannelMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityChannelMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentityChannel).
+func (m *AuthIdentityChannelMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityChannelMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentitychannel.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentitychannel.FieldUpdatedAt)
+ }
+ if m.identity != nil {
+ fields = append(fields, authidentitychannel.FieldIdentityID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentitychannel.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentitychannel.FieldProviderKey)
+ }
+ if m.channel != nil {
+ fields = append(fields, authidentitychannel.FieldChannel)
+ }
+ if m.channel_app_id != nil {
+ fields = append(fields, authidentitychannel.FieldChannelAppID)
+ }
+ if m.channel_subject != nil {
+ fields = append(fields, authidentitychannel.FieldChannelSubject)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentitychannel.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentitychannel.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentitychannel.FieldIdentityID:
+ return m.IdentityID()
+ case authidentitychannel.FieldProviderType:
+ return m.ProviderType()
+ case authidentitychannel.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentitychannel.FieldChannel:
+ return m.Channel()
+ case authidentitychannel.FieldChannelAppID:
+ return m.ChannelAppID()
+ case authidentitychannel.FieldChannelSubject:
+ return m.ChannelSubject()
+ case authidentitychannel.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentitychannel.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentitychannel.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case authidentitychannel.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentitychannel.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentitychannel.FieldChannel:
+ return m.OldChannel(ctx)
+ case authidentitychannel.FieldChannelAppID:
+ return m.OldChannelAppID(ctx)
+ case authidentitychannel.FieldChannelSubject:
+ return m.OldChannelSubject(ctx)
+ case authidentitychannel.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case authidentitychannel.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentitychannel.FieldChannel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannel(v)
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelAppID(v)
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelSubject(v)
+ return nil
+ case authidentitychannel.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityChannelMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityChannelMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetField(name string) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case authidentitychannel.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentitychannel.FieldChannel:
+ m.ResetChannel()
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ m.ResetChannelAppID()
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ m.ResetChannelSubject()
+ return nil
+ case authidentitychannel.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityChannelMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.identity != nil {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityChannelMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedidentity {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel edge %s", name)
+}
+
+// ChannelMonitorMutation represents an operation that mutates the ChannelMonitor nodes in the graph.
+type ChannelMonitorMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ provider *channelmonitor.Provider
+ endpoint *string
+ api_key_encrypted *string
+ primary_model *string
+ extra_models *[]string
+ appendextra_models []string
+ group_name *string
+ enabled *bool
+ interval_seconds *int
+ addinterval_seconds *int
+ last_checked_at *time.Time
+ created_by *int64
+ addcreated_by *int64
+ extra_headers *map[string]string
+ body_override_mode *string
+ body_override *map[string]interface{}
+ clearedFields map[string]struct{}
+ history map[int64]struct{}
+ removedhistory map[int64]struct{}
+ clearedhistory bool
+ daily_rollups map[int64]struct{}
+ removeddaily_rollups map[int64]struct{}
+ cleareddaily_rollups bool
+ request_template *int64
+ clearedrequest_template bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitor, error)
+ predicates []predicate.ChannelMonitor
+}
+
+var _ ent.Mutation = (*ChannelMonitorMutation)(nil)
+
+// channelmonitorOption allows management of the mutation configuration using functional options.
+type channelmonitorOption func(*ChannelMonitorMutation)
+
+// newChannelMonitorMutation creates new mutation for the ChannelMonitor entity.
+func newChannelMonitorMutation(c config, op Op, opts ...channelmonitorOption) *ChannelMonitorMutation {
+ m := &ChannelMonitorMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitor,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorID sets the ID field of the mutation.
+func withChannelMonitorID(id int64) channelmonitorOption {
+ return func(m *ChannelMonitorMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitor
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitor, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitor.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitor sets the old ChannelMonitor of the mutation.
+func withChannelMonitor(node *ChannelMonitor) channelmonitorOption {
+ return func(m *ChannelMonitorMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitor, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitor.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *ChannelMonitorMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *ChannelMonitorMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *ChannelMonitorMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *ChannelMonitorMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *ChannelMonitorMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *ChannelMonitorMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetName sets the "name" field.
+func (m *ChannelMonitorMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *ChannelMonitorMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *ChannelMonitorMutation) ResetName() {
+ m.name = nil
+}
+
+// SetProvider sets the "provider" field.
+func (m *ChannelMonitorMutation) SetProvider(c channelmonitor.Provider) {
+ m.provider = &c
+}
+
+// Provider returns the value of the "provider" field in the mutation.
+func (m *ChannelMonitorMutation) Provider() (r channelmonitor.Provider, exists bool) {
+ v := m.provider
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProvider returns the old "provider" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldProvider(ctx context.Context) (v channelmonitor.Provider, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProvider is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProvider requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProvider: %w", err)
+ }
+ return oldValue.Provider, nil
+}
+
+// ResetProvider resets all changes to the "provider" field.
+func (m *ChannelMonitorMutation) ResetProvider() {
+ m.provider = nil
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (m *ChannelMonitorMutation) SetEndpoint(s string) {
+ m.endpoint = &s
+}
+
+// Endpoint returns the value of the "endpoint" field in the mutation.
+func (m *ChannelMonitorMutation) Endpoint() (r string, exists bool) {
+ v := m.endpoint
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEndpoint returns the old "endpoint" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldEndpoint(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEndpoint is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEndpoint requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEndpoint: %w", err)
+ }
+ return oldValue.Endpoint, nil
+}
+
+// ResetEndpoint resets all changes to the "endpoint" field.
+func (m *ChannelMonitorMutation) ResetEndpoint() {
+ m.endpoint = nil
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (m *ChannelMonitorMutation) SetAPIKeyEncrypted(s string) {
+ m.api_key_encrypted = &s
+}
+
+// APIKeyEncrypted returns the value of the "api_key_encrypted" field in the mutation.
+func (m *ChannelMonitorMutation) APIKeyEncrypted() (r string, exists bool) {
+ v := m.api_key_encrypted
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAPIKeyEncrypted returns the old "api_key_encrypted" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldAPIKeyEncrypted(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAPIKeyEncrypted is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAPIKeyEncrypted requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAPIKeyEncrypted: %w", err)
+ }
+ return oldValue.APIKeyEncrypted, nil
+}
+
+// ResetAPIKeyEncrypted resets all changes to the "api_key_encrypted" field.
+func (m *ChannelMonitorMutation) ResetAPIKeyEncrypted() {
+ m.api_key_encrypted = nil
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (m *ChannelMonitorMutation) SetPrimaryModel(s string) {
+ m.primary_model = &s
+}
+
+// PrimaryModel returns the value of the "primary_model" field in the mutation.
+func (m *ChannelMonitorMutation) PrimaryModel() (r string, exists bool) {
+ v := m.primary_model
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPrimaryModel returns the old "primary_model" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldPrimaryModel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPrimaryModel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPrimaryModel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPrimaryModel: %w", err)
+ }
+ return oldValue.PrimaryModel, nil
+}
+
+// ResetPrimaryModel resets all changes to the "primary_model" field.
+func (m *ChannelMonitorMutation) ResetPrimaryModel() {
+ m.primary_model = nil
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (m *ChannelMonitorMutation) SetExtraModels(s []string) {
+ m.extra_models = &s
+ m.appendextra_models = nil
+}
+
+// ExtraModels returns the value of the "extra_models" field in the mutation.
+func (m *ChannelMonitorMutation) ExtraModels() (r []string, exists bool) {
+ v := m.extra_models
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraModels returns the old "extra_models" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldExtraModels(ctx context.Context) (v []string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraModels is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraModels requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraModels: %w", err)
+ }
+ return oldValue.ExtraModels, nil
+}
+
+// AppendExtraModels adds s to the "extra_models" field.
+func (m *ChannelMonitorMutation) AppendExtraModels(s []string) {
+ m.appendextra_models = append(m.appendextra_models, s...)
+}
+
+// AppendedExtraModels returns the list of values that were appended to the "extra_models" field in this mutation.
+func (m *ChannelMonitorMutation) AppendedExtraModels() ([]string, bool) {
+ if len(m.appendextra_models) == 0 {
+ return nil, false
+ }
+ return m.appendextra_models, true
+}
+
+// ResetExtraModels resets all changes to the "extra_models" field.
+func (m *ChannelMonitorMutation) ResetExtraModels() {
+ m.extra_models = nil
+ m.appendextra_models = nil
+}
+
+// SetGroupName sets the "group_name" field.
+func (m *ChannelMonitorMutation) SetGroupName(s string) {
+ m.group_name = &s
+}
+
+// GroupName returns the value of the "group_name" field in the mutation.
+func (m *ChannelMonitorMutation) GroupName() (r string, exists bool) {
+ v := m.group_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldGroupName returns the old "group_name" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldGroupName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldGroupName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldGroupName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldGroupName: %w", err)
+ }
+ return oldValue.GroupName, nil
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (m *ChannelMonitorMutation) ClearGroupName() {
+ m.group_name = nil
+ m.clearedFields[channelmonitor.FieldGroupName] = struct{}{}
+}
+
+// GroupNameCleared returns if the "group_name" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) GroupNameCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldGroupName]
+ return ok
+}
+
+// ResetGroupName resets all changes to the "group_name" field.
+func (m *ChannelMonitorMutation) ResetGroupName() {
+ m.group_name = nil
+ delete(m.clearedFields, channelmonitor.FieldGroupName)
+}
+
+// SetEnabled sets the "enabled" field.
+func (m *ChannelMonitorMutation) SetEnabled(b bool) {
+ m.enabled = &b
+}
+
+// Enabled returns the value of the "enabled" field in the mutation.
+func (m *ChannelMonitorMutation) Enabled() (r bool, exists bool) {
+ v := m.enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEnabled returns the old "enabled" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
+ }
+ return oldValue.Enabled, nil
+}
+
+// ResetEnabled resets all changes to the "enabled" field.
+func (m *ChannelMonitorMutation) ResetEnabled() {
+ m.enabled = nil
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (m *ChannelMonitorMutation) SetIntervalSeconds(i int) {
+ m.interval_seconds = &i
+ m.addinterval_seconds = nil
+}
+
+// IntervalSeconds returns the value of the "interval_seconds" field in the mutation.
+func (m *ChannelMonitorMutation) IntervalSeconds() (r int, exists bool) {
+ v := m.interval_seconds
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntervalSeconds returns the old "interval_seconds" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldIntervalSeconds(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntervalSeconds is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntervalSeconds requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntervalSeconds: %w", err)
+ }
+ return oldValue.IntervalSeconds, nil
+}
+
+// AddIntervalSeconds adds i to the "interval_seconds" field.
+func (m *ChannelMonitorMutation) AddIntervalSeconds(i int) {
+ if m.addinterval_seconds != nil {
+ *m.addinterval_seconds += i
+ } else {
+ m.addinterval_seconds = &i
+ }
+}
+
+// AddedIntervalSeconds returns the value that was added to the "interval_seconds" field in this mutation.
+func (m *ChannelMonitorMutation) AddedIntervalSeconds() (r int, exists bool) {
+ v := m.addinterval_seconds
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetIntervalSeconds resets all changes to the "interval_seconds" field.
+func (m *ChannelMonitorMutation) ResetIntervalSeconds() {
+ m.interval_seconds = nil
+ m.addinterval_seconds = nil
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (m *ChannelMonitorMutation) SetLastCheckedAt(t time.Time) {
+ m.last_checked_at = &t
+}
+
+// LastCheckedAt returns the value of the "last_checked_at" field in the mutation.
+func (m *ChannelMonitorMutation) LastCheckedAt() (r time.Time, exists bool) {
+ v := m.last_checked_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastCheckedAt returns the old "last_checked_at" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldLastCheckedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastCheckedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastCheckedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastCheckedAt: %w", err)
+ }
+ return oldValue.LastCheckedAt, nil
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (m *ChannelMonitorMutation) ClearLastCheckedAt() {
+ m.last_checked_at = nil
+ m.clearedFields[channelmonitor.FieldLastCheckedAt] = struct{}{}
+}
+
+// LastCheckedAtCleared returns if the "last_checked_at" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) LastCheckedAtCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldLastCheckedAt]
+ return ok
+}
+
+// ResetLastCheckedAt resets all changes to the "last_checked_at" field.
+func (m *ChannelMonitorMutation) ResetLastCheckedAt() {
+ m.last_checked_at = nil
+ delete(m.clearedFields, channelmonitor.FieldLastCheckedAt)
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (m *ChannelMonitorMutation) SetCreatedBy(i int64) {
+ m.created_by = &i
+ m.addcreated_by = nil
+}
+
+// CreatedBy returns the value of the "created_by" field in the mutation.
+func (m *ChannelMonitorMutation) CreatedBy() (r int64, exists bool) {
+ v := m.created_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedBy returns the old "created_by" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldCreatedBy(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err)
+ }
+ return oldValue.CreatedBy, nil
+}
+
+// AddCreatedBy adds i to the "created_by" field.
+func (m *ChannelMonitorMutation) AddCreatedBy(i int64) {
+ if m.addcreated_by != nil {
+ *m.addcreated_by += i
+ } else {
+ m.addcreated_by = &i
+ }
+}
+
+// AddedCreatedBy returns the value that was added to the "created_by" field in this mutation.
+func (m *ChannelMonitorMutation) AddedCreatedBy() (r int64, exists bool) {
+ v := m.addcreated_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCreatedBy resets all changes to the "created_by" field.
+func (m *ChannelMonitorMutation) ResetCreatedBy() {
+ m.created_by = nil
+ m.addcreated_by = nil
+}
+
+// SetTemplateID sets the "template_id" field.
+func (m *ChannelMonitorMutation) SetTemplateID(i int64) {
+ m.request_template = &i
+}
+
+// TemplateID returns the value of the "template_id" field in the mutation.
+func (m *ChannelMonitorMutation) TemplateID() (r int64, exists bool) {
+ v := m.request_template
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTemplateID returns the old "template_id" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldTemplateID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTemplateID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTemplateID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTemplateID: %w", err)
+ }
+ return oldValue.TemplateID, nil
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (m *ChannelMonitorMutation) ClearTemplateID() {
+ m.request_template = nil
+ m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{}
+}
+
+// TemplateIDCleared returns if the "template_id" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) TemplateIDCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldTemplateID]
+ return ok
+}
+
+// ResetTemplateID resets all changes to the "template_id" field.
+func (m *ChannelMonitorMutation) ResetTemplateID() {
+ m.request_template = nil
+ delete(m.clearedFields, channelmonitor.FieldTemplateID)
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (m *ChannelMonitorMutation) SetExtraHeaders(value map[string]string) {
+ m.extra_headers = &value
+}
+
+// ExtraHeaders returns the value of the "extra_headers" field in the mutation.
+func (m *ChannelMonitorMutation) ExtraHeaders() (r map[string]string, exists bool) {
+ v := m.extra_headers
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraHeaders requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err)
+ }
+ return oldValue.ExtraHeaders, nil
+}
+
+// ResetExtraHeaders resets all changes to the "extra_headers" field.
+func (m *ChannelMonitorMutation) ResetExtraHeaders() {
+ m.extra_headers = nil
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (m *ChannelMonitorMutation) SetBodyOverrideMode(s string) {
+ m.body_override_mode = &s
+}
+
+// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation.
+func (m *ChannelMonitorMutation) BodyOverrideMode() (r string, exists bool) {
+ v := m.body_override_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err)
+ }
+ return oldValue.BodyOverrideMode, nil
+}
+
+// ResetBodyOverrideMode resets all changes to the "body_override_mode" field.
+func (m *ChannelMonitorMutation) ResetBodyOverrideMode() {
+ m.body_override_mode = nil
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (m *ChannelMonitorMutation) SetBodyOverride(value map[string]interface{}) {
+ m.body_override = &value
+}
+
+// BodyOverride returns the value of the "body_override" field in the mutation.
+func (m *ChannelMonitorMutation) BodyOverride() (r map[string]interface{}, exists bool) {
+ v := m.body_override
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverride requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err)
+ }
+ return oldValue.BodyOverride, nil
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (m *ChannelMonitorMutation) ClearBodyOverride() {
+ m.body_override = nil
+ m.clearedFields[channelmonitor.FieldBodyOverride] = struct{}{}
+}
+
+// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) BodyOverrideCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldBodyOverride]
+ return ok
+}
+
+// ResetBodyOverride resets all changes to the "body_override" field.
+func (m *ChannelMonitorMutation) ResetBodyOverride() {
+ m.body_override = nil
+ delete(m.clearedFields, channelmonitor.FieldBodyOverride)
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by ids.
+func (m *ChannelMonitorMutation) AddHistoryIDs(ids ...int64) {
+ if m.history == nil {
+ m.history = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.history[ids[i]] = struct{}{}
+ }
+}
+
+// ClearHistory clears the "history" edge to the ChannelMonitorHistory entity.
+func (m *ChannelMonitorMutation) ClearHistory() {
+ m.clearedhistory = true
+}
+
+// HistoryCleared reports if the "history" edge to the ChannelMonitorHistory entity was cleared.
+func (m *ChannelMonitorMutation) HistoryCleared() bool {
+ return m.clearedhistory
+}
+
+// RemoveHistoryIDs removes the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (m *ChannelMonitorMutation) RemoveHistoryIDs(ids ...int64) {
+ if m.removedhistory == nil {
+ m.removedhistory = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.history, ids[i])
+ m.removedhistory[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedHistory returns the removed IDs of the "history" edge to the ChannelMonitorHistory entity.
+func (m *ChannelMonitorMutation) RemovedHistoryIDs() (ids []int64) {
+ for id := range m.removedhistory {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// HistoryIDs returns the "history" edge IDs in the mutation.
+func (m *ChannelMonitorMutation) HistoryIDs() (ids []int64) {
+ for id := range m.history {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetHistory resets all changes to the "history" edge.
+func (m *ChannelMonitorMutation) ResetHistory() {
+ m.history = nil
+ m.clearedhistory = false
+ m.removedhistory = nil
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by ids.
+func (m *ChannelMonitorMutation) AddDailyRollupIDs(ids ...int64) {
+ if m.daily_rollups == nil {
+ m.daily_rollups = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.daily_rollups[ids[i]] = struct{}{}
+ }
+}
+
+// ClearDailyRollups clears the "daily_rollups" edge to the ChannelMonitorDailyRollup entity.
+func (m *ChannelMonitorMutation) ClearDailyRollups() {
+ m.cleareddaily_rollups = true
+}
+
+// DailyRollupsCleared reports if the "daily_rollups" edge to the ChannelMonitorDailyRollup entity was cleared.
+func (m *ChannelMonitorMutation) DailyRollupsCleared() bool {
+ return m.cleareddaily_rollups
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (m *ChannelMonitorMutation) RemoveDailyRollupIDs(ids ...int64) {
+ if m.removeddaily_rollups == nil {
+ m.removeddaily_rollups = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.daily_rollups, ids[i])
+ m.removeddaily_rollups[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedDailyRollups returns the removed IDs of the "daily_rollups" edge to the ChannelMonitorDailyRollup entity.
+func (m *ChannelMonitorMutation) RemovedDailyRollupsIDs() (ids []int64) {
+ for id := range m.removeddaily_rollups {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// DailyRollupsIDs returns the "daily_rollups" edge IDs in the mutation.
+func (m *ChannelMonitorMutation) DailyRollupsIDs() (ids []int64) {
+ for id := range m.daily_rollups {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetDailyRollups resets all changes to the "daily_rollups" edge.
+func (m *ChannelMonitorMutation) ResetDailyRollups() {
+ m.daily_rollups = nil
+ m.cleareddaily_rollups = false
+ m.removeddaily_rollups = nil
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by id.
+func (m *ChannelMonitorMutation) SetRequestTemplateID(id int64) {
+ m.request_template = &id
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (m *ChannelMonitorMutation) ClearRequestTemplate() {
+ m.clearedrequest_template = true
+ m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{}
+}
+
+// RequestTemplateCleared reports if the "request_template" edge to the ChannelMonitorRequestTemplate entity was cleared.
+func (m *ChannelMonitorMutation) RequestTemplateCleared() bool {
+ return m.TemplateIDCleared() || m.clearedrequest_template
+}
+
+// RequestTemplateID returns the "request_template" edge ID in the mutation.
+func (m *ChannelMonitorMutation) RequestTemplateID() (id int64, exists bool) {
+ if m.request_template != nil {
+ return *m.request_template, true
+ }
+ return
+}
+
+// RequestTemplateIDs returns the "request_template" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// RequestTemplateID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorMutation) RequestTemplateIDs() (ids []int64) {
+ if id := m.request_template; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetRequestTemplate resets all changes to the "request_template" edge.
+func (m *ChannelMonitorMutation) ResetRequestTemplate() {
+ m.request_template = nil
+ m.clearedrequest_template = false
+}
+
+// Where appends a list predicates to the ChannelMonitorMutation builder.
+func (m *ChannelMonitorMutation) Where(ps ...predicate.ChannelMonitor) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitor, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitor).
+func (m *ChannelMonitorMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorMutation) Fields() []string {
+ fields := make([]string, 0, 17)
+ if m.created_at != nil {
+ fields = append(fields, channelmonitor.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, channelmonitor.FieldUpdatedAt)
+ }
+ if m.name != nil {
+ fields = append(fields, channelmonitor.FieldName)
+ }
+ if m.provider != nil {
+ fields = append(fields, channelmonitor.FieldProvider)
+ }
+ if m.endpoint != nil {
+ fields = append(fields, channelmonitor.FieldEndpoint)
+ }
+ if m.api_key_encrypted != nil {
+ fields = append(fields, channelmonitor.FieldAPIKeyEncrypted)
+ }
+ if m.primary_model != nil {
+ fields = append(fields, channelmonitor.FieldPrimaryModel)
+ }
+ if m.extra_models != nil {
+ fields = append(fields, channelmonitor.FieldExtraModels)
+ }
+ if m.group_name != nil {
+ fields = append(fields, channelmonitor.FieldGroupName)
+ }
+ if m.enabled != nil {
+ fields = append(fields, channelmonitor.FieldEnabled)
+ }
+ if m.interval_seconds != nil {
+ fields = append(fields, channelmonitor.FieldIntervalSeconds)
+ }
+ if m.last_checked_at != nil {
+ fields = append(fields, channelmonitor.FieldLastCheckedAt)
+ }
+ if m.created_by != nil {
+ fields = append(fields, channelmonitor.FieldCreatedBy)
+ }
+ if m.request_template != nil {
+ fields = append(fields, channelmonitor.FieldTemplateID)
+ }
+ if m.extra_headers != nil {
+ fields = append(fields, channelmonitor.FieldExtraHeaders)
+ }
+ if m.body_override_mode != nil {
+ fields = append(fields, channelmonitor.FieldBodyOverrideMode)
+ }
+ if m.body_override != nil {
+ fields = append(fields, channelmonitor.FieldBodyOverride)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ return m.CreatedAt()
+ case channelmonitor.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case channelmonitor.FieldName:
+ return m.Name()
+ case channelmonitor.FieldProvider:
+ return m.Provider()
+ case channelmonitor.FieldEndpoint:
+ return m.Endpoint()
+ case channelmonitor.FieldAPIKeyEncrypted:
+ return m.APIKeyEncrypted()
+ case channelmonitor.FieldPrimaryModel:
+ return m.PrimaryModel()
+ case channelmonitor.FieldExtraModels:
+ return m.ExtraModels()
+ case channelmonitor.FieldGroupName:
+ return m.GroupName()
+ case channelmonitor.FieldEnabled:
+ return m.Enabled()
+ case channelmonitor.FieldIntervalSeconds:
+ return m.IntervalSeconds()
+ case channelmonitor.FieldLastCheckedAt:
+ return m.LastCheckedAt()
+ case channelmonitor.FieldCreatedBy:
+ return m.CreatedBy()
+ case channelmonitor.FieldTemplateID:
+ return m.TemplateID()
+ case channelmonitor.FieldExtraHeaders:
+ return m.ExtraHeaders()
+ case channelmonitor.FieldBodyOverrideMode:
+ return m.BodyOverrideMode()
+ case channelmonitor.FieldBodyOverride:
+ return m.BodyOverride()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case channelmonitor.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case channelmonitor.FieldName:
+ return m.OldName(ctx)
+ case channelmonitor.FieldProvider:
+ return m.OldProvider(ctx)
+ case channelmonitor.FieldEndpoint:
+ return m.OldEndpoint(ctx)
+ case channelmonitor.FieldAPIKeyEncrypted:
+ return m.OldAPIKeyEncrypted(ctx)
+ case channelmonitor.FieldPrimaryModel:
+ return m.OldPrimaryModel(ctx)
+ case channelmonitor.FieldExtraModels:
+ return m.OldExtraModels(ctx)
+ case channelmonitor.FieldGroupName:
+ return m.OldGroupName(ctx)
+ case channelmonitor.FieldEnabled:
+ return m.OldEnabled(ctx)
+ case channelmonitor.FieldIntervalSeconds:
+ return m.OldIntervalSeconds(ctx)
+ case channelmonitor.FieldLastCheckedAt:
+ return m.OldLastCheckedAt(ctx)
+ case channelmonitor.FieldCreatedBy:
+ return m.OldCreatedBy(ctx)
+ case channelmonitor.FieldTemplateID:
+ return m.OldTemplateID(ctx)
+ case channelmonitor.FieldExtraHeaders:
+ return m.OldExtraHeaders(ctx)
+ case channelmonitor.FieldBodyOverrideMode:
+ return m.OldBodyOverrideMode(ctx)
+ case channelmonitor.FieldBodyOverride:
+ return m.OldBodyOverride(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitor field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case channelmonitor.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case channelmonitor.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case channelmonitor.FieldProvider:
+ v, ok := value.(channelmonitor.Provider)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProvider(v)
+ return nil
+ case channelmonitor.FieldEndpoint:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEndpoint(v)
+ return nil
+ case channelmonitor.FieldAPIKeyEncrypted:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAPIKeyEncrypted(v)
+ return nil
+ case channelmonitor.FieldPrimaryModel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPrimaryModel(v)
+ return nil
+ case channelmonitor.FieldExtraModels:
+ v, ok := value.([]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraModels(v)
+ return nil
+ case channelmonitor.FieldGroupName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetGroupName(v)
+ return nil
+ case channelmonitor.FieldEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEnabled(v)
+ return nil
+ case channelmonitor.FieldIntervalSeconds:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIntervalSeconds(v)
+ return nil
+ case channelmonitor.FieldLastCheckedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastCheckedAt(v)
+ return nil
+ case channelmonitor.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedBy(v)
+ return nil
+ case channelmonitor.FieldTemplateID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTemplateID(v)
+ return nil
+ case channelmonitor.FieldExtraHeaders:
+ v, ok := value.(map[string]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraHeaders(v)
+ return nil
+ case channelmonitor.FieldBodyOverrideMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverrideMode(v)
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverride(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorMutation) AddedFields() []string {
+ var fields []string
+ if m.addinterval_seconds != nil {
+ fields = append(fields, channelmonitor.FieldIntervalSeconds)
+ }
+ if m.addcreated_by != nil {
+ fields = append(fields, channelmonitor.FieldCreatedBy)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitor.FieldIntervalSeconds:
+ return m.AddedIntervalSeconds()
+ case channelmonitor.FieldCreatedBy:
+ return m.AddedCreatedBy()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitor.FieldIntervalSeconds:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddIntervalSeconds(v)
+ return nil
+ case channelmonitor.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCreatedBy(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitor.FieldGroupName) {
+ fields = append(fields, channelmonitor.FieldGroupName)
+ }
+ if m.FieldCleared(channelmonitor.FieldLastCheckedAt) {
+ fields = append(fields, channelmonitor.FieldLastCheckedAt)
+ }
+ if m.FieldCleared(channelmonitor.FieldTemplateID) {
+ fields = append(fields, channelmonitor.FieldTemplateID)
+ }
+ if m.FieldCleared(channelmonitor.FieldBodyOverride) {
+ fields = append(fields, channelmonitor.FieldBodyOverride)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitor.FieldGroupName:
+ m.ClearGroupName()
+ return nil
+ case channelmonitor.FieldLastCheckedAt:
+ m.ClearLastCheckedAt()
+ return nil
+ case channelmonitor.FieldTemplateID:
+ m.ClearTemplateID()
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ m.ClearBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case channelmonitor.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case channelmonitor.FieldName:
+ m.ResetName()
+ return nil
+ case channelmonitor.FieldProvider:
+ m.ResetProvider()
+ return nil
+ case channelmonitor.FieldEndpoint:
+ m.ResetEndpoint()
+ return nil
+ case channelmonitor.FieldAPIKeyEncrypted:
+ m.ResetAPIKeyEncrypted()
+ return nil
+ case channelmonitor.FieldPrimaryModel:
+ m.ResetPrimaryModel()
+ return nil
+ case channelmonitor.FieldExtraModels:
+ m.ResetExtraModels()
+ return nil
+ case channelmonitor.FieldGroupName:
+ m.ResetGroupName()
+ return nil
+ case channelmonitor.FieldEnabled:
+ m.ResetEnabled()
+ return nil
+ case channelmonitor.FieldIntervalSeconds:
+ m.ResetIntervalSeconds()
+ return nil
+ case channelmonitor.FieldLastCheckedAt:
+ m.ResetLastCheckedAt()
+ return nil
+ case channelmonitor.FieldCreatedBy:
+ m.ResetCreatedBy()
+ return nil
+ case channelmonitor.FieldTemplateID:
+ m.ResetTemplateID()
+ return nil
+ case channelmonitor.FieldExtraHeaders:
+ m.ResetExtraHeaders()
+ return nil
+ case channelmonitor.FieldBodyOverrideMode:
+ m.ResetBodyOverrideMode()
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ m.ResetBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.history != nil {
+ edges = append(edges, channelmonitor.EdgeHistory)
+ }
+ if m.daily_rollups != nil {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
+ if m.request_template != nil {
+ edges = append(edges, channelmonitor.EdgeRequestTemplate)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ ids := make([]ent.Value, 0, len(m.history))
+ for id := range m.history {
+ ids = append(ids, id)
+ }
+ return ids
+ case channelmonitor.EdgeDailyRollups:
+ ids := make([]ent.Value, 0, len(m.daily_rollups))
+ for id := range m.daily_rollups {
+ ids = append(ids, id)
+ }
+ return ids
+ case channelmonitor.EdgeRequestTemplate:
+ if id := m.request_template; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedhistory != nil {
+ edges = append(edges, channelmonitor.EdgeHistory)
+ }
+ if m.removeddaily_rollups != nil {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ ids := make([]ent.Value, 0, len(m.removedhistory))
+ for id := range m.removedhistory {
+ ids = append(ids, id)
+ }
+ return ids
+ case channelmonitor.EdgeDailyRollups:
+ ids := make([]ent.Value, 0, len(m.removeddaily_rollups))
+ for id := range m.removeddaily_rollups {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.clearedhistory {
+ edges = append(edges, channelmonitor.EdgeHistory)
+ }
+ if m.cleareddaily_rollups {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
+ if m.clearedrequest_template {
+ edges = append(edges, channelmonitor.EdgeRequestTemplate)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ return m.clearedhistory
+ case channelmonitor.EdgeDailyRollups:
+ return m.cleareddaily_rollups
+ case channelmonitor.EdgeRequestTemplate:
+ return m.clearedrequest_template
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorMutation) ClearEdge(name string) error {
+ switch name {
+ case channelmonitor.EdgeRequestTemplate:
+ m.ClearRequestTemplate()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ m.ResetHistory()
+ return nil
+ case channelmonitor.EdgeDailyRollups:
+ m.ResetDailyRollups()
+ return nil
+ case channelmonitor.EdgeRequestTemplate:
+ m.ResetRequestTemplate()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor edge %s", name)
+}
+
+// ChannelMonitorDailyRollupMutation represents an operation that mutates the ChannelMonitorDailyRollup nodes in the graph.
+type ChannelMonitorDailyRollupMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ model *string
+ bucket_date *time.Time
+ total_checks *int
+ addtotal_checks *int
+ ok_count *int
+ addok_count *int
+ operational_count *int
+ addoperational_count *int
+ degraded_count *int
+ adddegraded_count *int
+ failed_count *int
+ addfailed_count *int
+ error_count *int
+ adderror_count *int
+ sum_latency_ms *int64
+ addsum_latency_ms *int64
+ count_latency *int
+ addcount_latency *int
+ sum_ping_latency_ms *int64
+ addsum_ping_latency_ms *int64
+ count_ping_latency *int
+ addcount_ping_latency *int
+ computed_at *time.Time
+ clearedFields map[string]struct{}
+ monitor *int64
+ clearedmonitor bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorDailyRollup, error)
+ predicates []predicate.ChannelMonitorDailyRollup
+}
+
+var _ ent.Mutation = (*ChannelMonitorDailyRollupMutation)(nil)
+
+// channelmonitordailyrollupOption allows management of the mutation configuration using functional options.
+type channelmonitordailyrollupOption func(*ChannelMonitorDailyRollupMutation)
+
+// newChannelMonitorDailyRollupMutation creates new mutation for the ChannelMonitorDailyRollup entity.
+func newChannelMonitorDailyRollupMutation(c config, op Op, opts ...channelmonitordailyrollupOption) *ChannelMonitorDailyRollupMutation {
+ m := &ChannelMonitorDailyRollupMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitorDailyRollup,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorDailyRollupID sets the ID field of the mutation.
+func withChannelMonitorDailyRollupID(id int64) channelmonitordailyrollupOption {
+ return func(m *ChannelMonitorDailyRollupMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitorDailyRollup
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitorDailyRollup.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitorDailyRollup sets the old ChannelMonitorDailyRollup of the mutation.
+func withChannelMonitorDailyRollup(node *ChannelMonitorDailyRollup) channelmonitordailyrollupOption {
+ return func(m *ChannelMonitorDailyRollupMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorDailyRollup, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorDailyRollupMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorDailyRollupMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorDailyRollupMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorDailyRollupMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorDailyRollup.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (m *ChannelMonitorDailyRollupMutation) SetMonitorID(i int64) {
+ m.monitor = &i
+}
+
+// MonitorID returns the value of the "monitor_id" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) MonitorID() (r int64, exists bool) {
+ v := m.monitor
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMonitorID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMonitorID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMonitorID: %w", err)
+ }
+ return oldValue.MonitorID, nil
+}
+
+// ResetMonitorID resets all changes to the "monitor_id" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitorID() {
+ m.monitor = nil
+}
+
+// SetModel sets the "model" field.
+func (m *ChannelMonitorDailyRollupMutation) SetModel(s string) {
+ m.model = &s
+}
+
+// Model returns the value of the "model" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) Model() (r string, exists bool) {
+ v := m.model
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModel returns the old "model" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldModel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModel: %w", err)
+ }
+ return oldValue.Model, nil
+}
+
+// ResetModel resets all changes to the "model" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetModel() {
+ m.model = nil
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (m *ChannelMonitorDailyRollupMutation) SetBucketDate(t time.Time) {
+ m.bucket_date = &t
+}
+
+// BucketDate returns the value of the "bucket_date" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) BucketDate() (r time.Time, exists bool) {
+ v := m.bucket_date
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBucketDate returns the old "bucket_date" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldBucketDate(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBucketDate is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBucketDate requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBucketDate: %w", err)
+ }
+ return oldValue.BucketDate, nil
+}
+
+// ResetBucketDate resets all changes to the "bucket_date" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetBucketDate() {
+ m.bucket_date = nil
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) SetTotalChecks(i int) {
+ m.total_checks = &i
+ m.addtotal_checks = nil
+}
+
+// TotalChecks returns the value of the "total_checks" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) TotalChecks() (r int, exists bool) {
+ v := m.total_checks
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotalChecks returns the old "total_checks" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldTotalChecks(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotalChecks is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotalChecks requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotalChecks: %w", err)
+ }
+ return oldValue.TotalChecks, nil
+}
+
+// AddTotalChecks adds i to the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) AddTotalChecks(i int) {
+ if m.addtotal_checks != nil {
+ *m.addtotal_checks += i
+ } else {
+ m.addtotal_checks = &i
+ }
+}
+
+// AddedTotalChecks returns the value that was added to the "total_checks" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedTotalChecks() (r int, exists bool) {
+ v := m.addtotal_checks
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetTotalChecks resets all changes to the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetTotalChecks() {
+ m.total_checks = nil
+ m.addtotal_checks = nil
+}
+
+// SetOkCount sets the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetOkCount(i int) {
+ m.ok_count = &i
+ m.addok_count = nil
+}
+
+// OkCount returns the value of the "ok_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) OkCount() (r int, exists bool) {
+ v := m.ok_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOkCount returns the old "ok_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldOkCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOkCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOkCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOkCount: %w", err)
+ }
+ return oldValue.OkCount, nil
+}
+
+// AddOkCount adds i to the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddOkCount(i int) {
+ if m.addok_count != nil {
+ *m.addok_count += i
+ } else {
+ m.addok_count = &i
+ }
+}
+
+// AddedOkCount returns the value that was added to the "ok_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedOkCount() (r int, exists bool) {
+ v := m.addok_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetOkCount resets all changes to the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetOkCount() {
+ m.ok_count = nil
+ m.addok_count = nil
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetOperationalCount(i int) {
+ m.operational_count = &i
+ m.addoperational_count = nil
+}
+
+// OperationalCount returns the value of the "operational_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) OperationalCount() (r int, exists bool) {
+ v := m.operational_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOperationalCount returns the old "operational_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldOperationalCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOperationalCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOperationalCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOperationalCount: %w", err)
+ }
+ return oldValue.OperationalCount, nil
+}
+
+// AddOperationalCount adds i to the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddOperationalCount(i int) {
+ if m.addoperational_count != nil {
+ *m.addoperational_count += i
+ } else {
+ m.addoperational_count = &i
+ }
+}
+
+// AddedOperationalCount returns the value that was added to the "operational_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedOperationalCount() (r int, exists bool) {
+ v := m.addoperational_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetOperationalCount resets all changes to the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetOperationalCount() {
+ m.operational_count = nil
+ m.addoperational_count = nil
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetDegradedCount(i int) {
+ m.degraded_count = &i
+ m.adddegraded_count = nil
+}
+
+// DegradedCount returns the value of the "degraded_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) DegradedCount() (r int, exists bool) {
+ v := m.degraded_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDegradedCount returns the old "degraded_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldDegradedCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDegradedCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDegradedCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDegradedCount: %w", err)
+ }
+ return oldValue.DegradedCount, nil
+}
+
+// AddDegradedCount adds i to the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddDegradedCount(i int) {
+ if m.adddegraded_count != nil {
+ *m.adddegraded_count += i
+ } else {
+ m.adddegraded_count = &i
+ }
+}
+
+// AddedDegradedCount returns the value that was added to the "degraded_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedDegradedCount() (r int, exists bool) {
+ v := m.adddegraded_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetDegradedCount resets all changes to the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetDegradedCount() {
+ m.degraded_count = nil
+ m.adddegraded_count = nil
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetFailedCount(i int) {
+ m.failed_count = &i
+ m.addfailed_count = nil
+}
+
+// FailedCount returns the value of the "failed_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) FailedCount() (r int, exists bool) {
+ v := m.failed_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFailedCount returns the old "failed_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldFailedCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFailedCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFailedCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFailedCount: %w", err)
+ }
+ return oldValue.FailedCount, nil
+}
+
+// AddFailedCount adds i to the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddFailedCount(i int) {
+ if m.addfailed_count != nil {
+ *m.addfailed_count += i
+ } else {
+ m.addfailed_count = &i
+ }
+}
+
+// AddedFailedCount returns the value that was added to the "failed_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedFailedCount() (r int, exists bool) {
+ v := m.addfailed_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetFailedCount resets all changes to the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetFailedCount() {
+ m.failed_count = nil
+ m.addfailed_count = nil
+}
+
+// SetErrorCount sets the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetErrorCount(i int) {
+ m.error_count = &i
+ m.adderror_count = nil
+}
+
+// ErrorCount returns the value of the "error_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) ErrorCount() (r int, exists bool) {
+ v := m.error_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldErrorCount returns the old "error_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldErrorCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldErrorCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldErrorCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldErrorCount: %w", err)
+ }
+ return oldValue.ErrorCount, nil
+}
+
+// AddErrorCount adds i to the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddErrorCount(i int) {
+ if m.adderror_count != nil {
+ *m.adderror_count += i
+ } else {
+ m.adderror_count = &i
+ }
+}
+
+// AddedErrorCount returns the value that was added to the "error_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedErrorCount() (r int, exists bool) {
+ v := m.adderror_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetErrorCount resets all changes to the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetErrorCount() {
+ m.error_count = nil
+ m.adderror_count = nil
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) SetSumLatencyMs(i int64) {
+ m.sum_latency_ms = &i
+ m.addsum_latency_ms = nil
+}
+
+// SumLatencyMs returns the value of the "sum_latency_ms" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) SumLatencyMs() (r int64, exists bool) {
+ v := m.sum_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSumLatencyMs returns the old "sum_latency_ms" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldSumLatencyMs(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSumLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSumLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSumLatencyMs: %w", err)
+ }
+ return oldValue.SumLatencyMs, nil
+}
+
+// AddSumLatencyMs adds i to the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) AddSumLatencyMs(i int64) {
+ if m.addsum_latency_ms != nil {
+ *m.addsum_latency_ms += i
+ } else {
+ m.addsum_latency_ms = &i
+ }
+}
+
+// AddedSumLatencyMs returns the value that was added to the "sum_latency_ms" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedSumLatencyMs() (r int64, exists bool) {
+ v := m.addsum_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSumLatencyMs resets all changes to the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetSumLatencyMs() {
+ m.sum_latency_ms = nil
+ m.addsum_latency_ms = nil
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) SetCountLatency(i int) {
+ m.count_latency = &i
+ m.addcount_latency = nil
+}
+
+// CountLatency returns the value of the "count_latency" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) CountLatency() (r int, exists bool) {
+ v := m.count_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCountLatency returns the old "count_latency" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldCountLatency(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCountLatency is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCountLatency requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCountLatency: %w", err)
+ }
+ return oldValue.CountLatency, nil
+}
+
+// AddCountLatency adds i to the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) AddCountLatency(i int) {
+ if m.addcount_latency != nil {
+ *m.addcount_latency += i
+ } else {
+ m.addcount_latency = &i
+ }
+}
+
+// AddedCountLatency returns the value that was added to the "count_latency" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedCountLatency() (r int, exists bool) {
+ v := m.addcount_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCountLatency resets all changes to the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetCountLatency() {
+ m.count_latency = nil
+ m.addcount_latency = nil
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) SetSumPingLatencyMs(i int64) {
+ m.sum_ping_latency_ms = &i
+ m.addsum_ping_latency_ms = nil
+}
+
+// SumPingLatencyMs returns the value of the "sum_ping_latency_ms" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) SumPingLatencyMs() (r int64, exists bool) {
+ v := m.sum_ping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSumPingLatencyMs returns the old "sum_ping_latency_ms" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldSumPingLatencyMs(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSumPingLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSumPingLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSumPingLatencyMs: %w", err)
+ }
+ return oldValue.SumPingLatencyMs, nil
+}
+
+// AddSumPingLatencyMs adds i to the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) AddSumPingLatencyMs(i int64) {
+ if m.addsum_ping_latency_ms != nil {
+ *m.addsum_ping_latency_ms += i
+ } else {
+ m.addsum_ping_latency_ms = &i
+ }
+}
+
+// AddedSumPingLatencyMs returns the value that was added to the "sum_ping_latency_ms" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedSumPingLatencyMs() (r int64, exists bool) {
+ v := m.addsum_ping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSumPingLatencyMs resets all changes to the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetSumPingLatencyMs() {
+ m.sum_ping_latency_ms = nil
+ m.addsum_ping_latency_ms = nil
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) SetCountPingLatency(i int) {
+ m.count_ping_latency = &i
+ m.addcount_ping_latency = nil
+}
+
+// CountPingLatency returns the value of the "count_ping_latency" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) CountPingLatency() (r int, exists bool) {
+ v := m.count_ping_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCountPingLatency returns the old "count_ping_latency" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldCountPingLatency(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCountPingLatency is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCountPingLatency requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCountPingLatency: %w", err)
+ }
+ return oldValue.CountPingLatency, nil
+}
+
+// AddCountPingLatency adds i to the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) AddCountPingLatency(i int) {
+ if m.addcount_ping_latency != nil {
+ *m.addcount_ping_latency += i
+ } else {
+ m.addcount_ping_latency = &i
+ }
+}
+
+// AddedCountPingLatency returns the value that was added to the "count_ping_latency" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedCountPingLatency() (r int, exists bool) {
+ v := m.addcount_ping_latency
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCountPingLatency resets all changes to the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetCountPingLatency() {
+ m.count_ping_latency = nil
+ m.addcount_ping_latency = nil
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (m *ChannelMonitorDailyRollupMutation) SetComputedAt(t time.Time) {
+ m.computed_at = &t
+}
+
+// ComputedAt returns the value of the "computed_at" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) ComputedAt() (r time.Time, exists bool) {
+ v := m.computed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldComputedAt returns the old "computed_at" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldComputedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldComputedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldComputedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldComputedAt: %w", err)
+ }
+ return oldValue.ComputedAt, nil
+}
+
+// ResetComputedAt resets all changes to the "computed_at" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetComputedAt() {
+ m.computed_at = nil
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorDailyRollupMutation) ClearMonitor() {
+ m.clearedmonitor = true
+ m.clearedFields[channelmonitordailyrollup.FieldMonitorID] = struct{}{}
+}
+
+// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorDailyRollupMutation) MonitorCleared() bool {
+ return m.clearedmonitor
+}
+
+// MonitorIDs returns the "monitor" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// MonitorID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorDailyRollupMutation) MonitorIDs() (ids []int64) {
+ if id := m.monitor; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetMonitor resets all changes to the "monitor" edge.
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitor() {
+ m.monitor = nil
+ m.clearedmonitor = false
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupMutation builder.
+func (m *ChannelMonitorDailyRollupMutation) Where(ps ...predicate.ChannelMonitorDailyRollup) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorDailyRollupMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorDailyRollupMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitorDailyRollup, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorDailyRollupMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorDailyRollupMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorDailyRollup).
+func (m *ChannelMonitorDailyRollupMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorDailyRollupMutation) Fields() []string {
+ fields := make([]string, 0, 14)
+ if m.monitor != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldMonitorID)
+ }
+ if m.model != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldModel)
+ }
+ if m.bucket_date != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldBucketDate)
+ }
+ if m.total_checks != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldTotalChecks)
+ }
+ if m.ok_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOkCount)
+ }
+ if m.operational_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOperationalCount)
+ }
+ if m.degraded_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldDegradedCount)
+ }
+ if m.failed_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldFailedCount)
+ }
+ if m.error_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldErrorCount)
+ }
+ if m.sum_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs)
+ }
+ if m.count_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountLatency)
+ }
+ if m.sum_ping_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs)
+ }
+ if m.count_ping_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency)
+ }
+ if m.computed_at != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldComputedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorDailyRollupMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitordailyrollup.FieldMonitorID:
+ return m.MonitorID()
+ case channelmonitordailyrollup.FieldModel:
+ return m.Model()
+ case channelmonitordailyrollup.FieldBucketDate:
+ return m.BucketDate()
+ case channelmonitordailyrollup.FieldTotalChecks:
+ return m.TotalChecks()
+ case channelmonitordailyrollup.FieldOkCount:
+ return m.OkCount()
+ case channelmonitordailyrollup.FieldOperationalCount:
+ return m.OperationalCount()
+ case channelmonitordailyrollup.FieldDegradedCount:
+ return m.DegradedCount()
+ case channelmonitordailyrollup.FieldFailedCount:
+ return m.FailedCount()
+ case channelmonitordailyrollup.FieldErrorCount:
+ return m.ErrorCount()
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ return m.SumLatencyMs()
+ case channelmonitordailyrollup.FieldCountLatency:
+ return m.CountLatency()
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ return m.SumPingLatencyMs()
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ return m.CountPingLatency()
+ case channelmonitordailyrollup.FieldComputedAt:
+ return m.ComputedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorDailyRollupMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitordailyrollup.FieldMonitorID:
+ return m.OldMonitorID(ctx)
+ case channelmonitordailyrollup.FieldModel:
+ return m.OldModel(ctx)
+ case channelmonitordailyrollup.FieldBucketDate:
+ return m.OldBucketDate(ctx)
+ case channelmonitordailyrollup.FieldTotalChecks:
+ return m.OldTotalChecks(ctx)
+ case channelmonitordailyrollup.FieldOkCount:
+ return m.OldOkCount(ctx)
+ case channelmonitordailyrollup.FieldOperationalCount:
+ return m.OldOperationalCount(ctx)
+ case channelmonitordailyrollup.FieldDegradedCount:
+ return m.OldDegradedCount(ctx)
+ case channelmonitordailyrollup.FieldFailedCount:
+ return m.OldFailedCount(ctx)
+ case channelmonitordailyrollup.FieldErrorCount:
+ return m.OldErrorCount(ctx)
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ return m.OldSumLatencyMs(ctx)
+ case channelmonitordailyrollup.FieldCountLatency:
+ return m.OldCountLatency(ctx)
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ return m.OldSumPingLatencyMs(ctx)
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ return m.OldCountPingLatency(ctx)
+ case channelmonitordailyrollup.FieldComputedAt:
+ return m.OldComputedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorDailyRollupMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitordailyrollup.FieldMonitorID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMonitorID(v)
+ return nil
+ case channelmonitordailyrollup.FieldModel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModel(v)
+ return nil
+ case channelmonitordailyrollup.FieldBucketDate:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBucketDate(v)
+ return nil
+ case channelmonitordailyrollup.FieldTotalChecks:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotalChecks(v)
+ return nil
+ case channelmonitordailyrollup.FieldOkCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOkCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldOperationalCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOperationalCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldDegradedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDegradedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldFailedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFailedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldErrorCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetErrorCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSumLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCountLatency(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSumPingLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCountPingLatency(v)
+ return nil
+ case channelmonitordailyrollup.FieldComputedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetComputedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedFields() []string {
+ var fields []string
+ if m.addtotal_checks != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldTotalChecks)
+ }
+ if m.addok_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOkCount)
+ }
+ if m.addoperational_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldOperationalCount)
+ }
+ if m.adddegraded_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldDegradedCount)
+ }
+ if m.addfailed_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldFailedCount)
+ }
+ if m.adderror_count != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldErrorCount)
+ }
+ if m.addsum_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs)
+ }
+ if m.addcount_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountLatency)
+ }
+ if m.addsum_ping_latency_ms != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs)
+ }
+ if m.addcount_ping_latency != nil {
+ fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitordailyrollup.FieldTotalChecks:
+ return m.AddedTotalChecks()
+ case channelmonitordailyrollup.FieldOkCount:
+ return m.AddedOkCount()
+ case channelmonitordailyrollup.FieldOperationalCount:
+ return m.AddedOperationalCount()
+ case channelmonitordailyrollup.FieldDegradedCount:
+ return m.AddedDegradedCount()
+ case channelmonitordailyrollup.FieldFailedCount:
+ return m.AddedFailedCount()
+ case channelmonitordailyrollup.FieldErrorCount:
+ return m.AddedErrorCount()
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ return m.AddedSumLatencyMs()
+ case channelmonitordailyrollup.FieldCountLatency:
+ return m.AddedCountLatency()
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ return m.AddedSumPingLatencyMs()
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ return m.AddedCountPingLatency()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorDailyRollupMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitordailyrollup.FieldTotalChecks:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddTotalChecks(v)
+ return nil
+ case channelmonitordailyrollup.FieldOkCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddOkCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldOperationalCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddOperationalCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldDegradedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddDegradedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldFailedCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddFailedCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldErrorCount:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddErrorCount(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSumLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCountLatency(v)
+ return nil
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSumPingLatencyMs(v)
+ return nil
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCountPingLatency(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorDailyRollupMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitordailyrollup.FieldMonitorID:
+ m.ResetMonitorID()
+ return nil
+ case channelmonitordailyrollup.FieldModel:
+ m.ResetModel()
+ return nil
+ case channelmonitordailyrollup.FieldBucketDate:
+ m.ResetBucketDate()
+ return nil
+ case channelmonitordailyrollup.FieldTotalChecks:
+ m.ResetTotalChecks()
+ return nil
+ case channelmonitordailyrollup.FieldOkCount:
+ m.ResetOkCount()
+ return nil
+ case channelmonitordailyrollup.FieldOperationalCount:
+ m.ResetOperationalCount()
+ return nil
+ case channelmonitordailyrollup.FieldDegradedCount:
+ m.ResetDegradedCount()
+ return nil
+ case channelmonitordailyrollup.FieldFailedCount:
+ m.ResetFailedCount()
+ return nil
+ case channelmonitordailyrollup.FieldErrorCount:
+ m.ResetErrorCount()
+ return nil
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ m.ResetSumLatencyMs()
+ return nil
+ case channelmonitordailyrollup.FieldCountLatency:
+ m.ResetCountLatency()
+ return nil
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ m.ResetSumPingLatencyMs()
+ return nil
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ m.ResetCountPingLatency()
+ return nil
+ case channelmonitordailyrollup.FieldComputedAt:
+ m.ResetComputedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitor != nil {
+ edges = append(edges, channelmonitordailyrollup.EdgeMonitor)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ if id := m.monitor; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedmonitor {
+ edges = append(edges, channelmonitordailyrollup.EdgeMonitor)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ return m.clearedmonitor
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ClearEdge(name string) error {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ m.ClearMonitor()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ m.ResetMonitor()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorDailyRollup edge %s", name)
+}
+
+// ChannelMonitorHistoryMutation represents an operation that mutates the ChannelMonitorHistory nodes in the graph.
+type ChannelMonitorHistoryMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ model *string
+ status *channelmonitorhistory.Status
+ latency_ms *int
+ addlatency_ms *int
+ ping_latency_ms *int
+ addping_latency_ms *int
+ message *string
+ checked_at *time.Time
+ clearedFields map[string]struct{}
+ monitor *int64
+ clearedmonitor bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorHistory, error)
+ predicates []predicate.ChannelMonitorHistory
+}
+
+var _ ent.Mutation = (*ChannelMonitorHistoryMutation)(nil)
+
+// channelmonitorhistoryOption allows management of the mutation configuration using functional options.
+type channelmonitorhistoryOption func(*ChannelMonitorHistoryMutation)
+
+// newChannelMonitorHistoryMutation creates new mutation for the ChannelMonitorHistory entity.
+func newChannelMonitorHistoryMutation(c config, op Op, opts ...channelmonitorhistoryOption) *ChannelMonitorHistoryMutation {
+ m := &ChannelMonitorHistoryMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitorHistory,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorHistoryID sets the ID field of the mutation.
+func withChannelMonitorHistoryID(id int64) channelmonitorhistoryOption {
+ return func(m *ChannelMonitorHistoryMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitorHistory
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorHistory, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitorHistory.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitorHistory sets the old ChannelMonitorHistory of the mutation.
+func withChannelMonitorHistory(node *ChannelMonitorHistory) channelmonitorhistoryOption {
+ return func(m *ChannelMonitorHistoryMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorHistory, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorHistoryMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorHistoryMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorHistoryMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorHistory.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (m *ChannelMonitorHistoryMutation) SetMonitorID(i int64) {
+ m.monitor = &i
+}
+
+// MonitorID returns the value of the "monitor_id" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) MonitorID() (r int64, exists bool) {
+ v := m.monitor
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMonitorID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMonitorID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMonitorID: %w", err)
+ }
+ return oldValue.MonitorID, nil
+}
+
+// ResetMonitorID resets all changes to the "monitor_id" field.
+func (m *ChannelMonitorHistoryMutation) ResetMonitorID() {
+ m.monitor = nil
+}
+
+// SetModel sets the "model" field.
+func (m *ChannelMonitorHistoryMutation) SetModel(s string) {
+ m.model = &s
+}
+
+// Model returns the value of the "model" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) Model() (r string, exists bool) {
+ v := m.model
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModel returns the old "model" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldModel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModel: %w", err)
+ }
+ return oldValue.Model, nil
+}
+
+// ResetModel resets all changes to the "model" field.
+func (m *ChannelMonitorHistoryMutation) ResetModel() {
+ m.model = nil
+}
+
+// SetStatus sets the "status" field.
+func (m *ChannelMonitorHistoryMutation) SetStatus(c channelmonitorhistory.Status) {
+ m.status = &c
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) Status() (r channelmonitorhistory.Status, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldStatus(ctx context.Context) (v channelmonitorhistory.Status, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *ChannelMonitorHistoryMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) SetLatencyMs(i int) {
+ m.latency_ms = &i
+ m.addlatency_ms = nil
+}
+
+// LatencyMs returns the value of the "latency_ms" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) LatencyMs() (r int, exists bool) {
+ v := m.latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLatencyMs returns the old "latency_ms" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldLatencyMs(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLatencyMs: %w", err)
+ }
+ return oldValue.LatencyMs, nil
+}
+
+// AddLatencyMs adds i to the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) AddLatencyMs(i int) {
+ if m.addlatency_ms != nil {
+ *m.addlatency_ms += i
+ } else {
+ m.addlatency_ms = &i
+ }
+}
+
+// AddedLatencyMs returns the value that was added to the "latency_ms" field in this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedLatencyMs() (r int, exists bool) {
+ v := m.addlatency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ClearLatencyMs() {
+ m.latency_ms = nil
+ m.addlatency_ms = nil
+ m.clearedFields[channelmonitorhistory.FieldLatencyMs] = struct{}{}
+}
+
+// LatencyMsCleared returns if the "latency_ms" field was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) LatencyMsCleared() bool {
+ _, ok := m.clearedFields[channelmonitorhistory.FieldLatencyMs]
+ return ok
+}
+
+// ResetLatencyMs resets all changes to the "latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ResetLatencyMs() {
+ m.latency_ms = nil
+ m.addlatency_ms = nil
+ delete(m.clearedFields, channelmonitorhistory.FieldLatencyMs)
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) SetPingLatencyMs(i int) {
+ m.ping_latency_ms = &i
+ m.addping_latency_ms = nil
+}
+
+// PingLatencyMs returns the value of the "ping_latency_ms" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) PingLatencyMs() (r int, exists bool) {
+ v := m.ping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPingLatencyMs returns the old "ping_latency_ms" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldPingLatencyMs(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPingLatencyMs is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPingLatencyMs requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPingLatencyMs: %w", err)
+ }
+ return oldValue.PingLatencyMs, nil
+}
+
+// AddPingLatencyMs adds i to the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) AddPingLatencyMs(i int) {
+ if m.addping_latency_ms != nil {
+ *m.addping_latency_ms += i
+ } else {
+ m.addping_latency_ms = &i
+ }
+}
+
+// AddedPingLatencyMs returns the value that was added to the "ping_latency_ms" field in this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedPingLatencyMs() (r int, exists bool) {
+ v := m.addping_latency_ms
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ClearPingLatencyMs() {
+ m.ping_latency_ms = nil
+ m.addping_latency_ms = nil
+ m.clearedFields[channelmonitorhistory.FieldPingLatencyMs] = struct{}{}
+}
+
+// PingLatencyMsCleared returns if the "ping_latency_ms" field was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) PingLatencyMsCleared() bool {
+ _, ok := m.clearedFields[channelmonitorhistory.FieldPingLatencyMs]
+ return ok
+}
+
+// ResetPingLatencyMs resets all changes to the "ping_latency_ms" field.
+func (m *ChannelMonitorHistoryMutation) ResetPingLatencyMs() {
+ m.ping_latency_ms = nil
+ m.addping_latency_ms = nil
+ delete(m.clearedFields, channelmonitorhistory.FieldPingLatencyMs)
+}
+
+// SetMessage sets the "message" field.
+func (m *ChannelMonitorHistoryMutation) SetMessage(s string) {
+ m.message = &s
+}
+
+// Message returns the value of the "message" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) Message() (r string, exists bool) {
+ v := m.message
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMessage returns the old "message" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldMessage(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMessage is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMessage requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMessage: %w", err)
+ }
+ return oldValue.Message, nil
+}
+
+// ClearMessage clears the value of the "message" field.
+func (m *ChannelMonitorHistoryMutation) ClearMessage() {
+ m.message = nil
+ m.clearedFields[channelmonitorhistory.FieldMessage] = struct{}{}
+}
+
+// MessageCleared returns if the "message" field was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) MessageCleared() bool {
+ _, ok := m.clearedFields[channelmonitorhistory.FieldMessage]
+ return ok
+}
+
+// ResetMessage resets all changes to the "message" field.
+func (m *ChannelMonitorHistoryMutation) ResetMessage() {
+ m.message = nil
+ delete(m.clearedFields, channelmonitorhistory.FieldMessage)
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (m *ChannelMonitorHistoryMutation) SetCheckedAt(t time.Time) {
+ m.checked_at = &t
+}
+
+// CheckedAt returns the value of the "checked_at" field in the mutation.
+func (m *ChannelMonitorHistoryMutation) CheckedAt() (r time.Time, exists bool) {
+ v := m.checked_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCheckedAt returns the old "checked_at" field's value of the ChannelMonitorHistory entity.
+// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorHistoryMutation) OldCheckedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCheckedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCheckedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCheckedAt: %w", err)
+ }
+ return oldValue.CheckedAt, nil
+}
+
+// ResetCheckedAt resets all changes to the "checked_at" field.
+func (m *ChannelMonitorHistoryMutation) ResetCheckedAt() {
+ m.checked_at = nil
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorHistoryMutation) ClearMonitor() {
+ m.clearedmonitor = true
+ m.clearedFields[channelmonitorhistory.FieldMonitorID] = struct{}{}
+}
+
+// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorHistoryMutation) MonitorCleared() bool {
+ return m.clearedmonitor
+}
+
+// MonitorIDs returns the "monitor" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// MonitorID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorHistoryMutation) MonitorIDs() (ids []int64) {
+ if id := m.monitor; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetMonitor resets all changes to the "monitor" edge.
+func (m *ChannelMonitorHistoryMutation) ResetMonitor() {
+ m.monitor = nil
+ m.clearedmonitor = false
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryMutation builder.
+func (m *ChannelMonitorHistoryMutation) Where(ps ...predicate.ChannelMonitorHistory) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorHistoryMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorHistoryMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitorHistory, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorHistoryMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorHistoryMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorHistory).
+func (m *ChannelMonitorHistoryMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorHistoryMutation) Fields() []string {
+ fields := make([]string, 0, 7)
+ if m.monitor != nil {
+ fields = append(fields, channelmonitorhistory.FieldMonitorID)
+ }
+ if m.model != nil {
+ fields = append(fields, channelmonitorhistory.FieldModel)
+ }
+ if m.status != nil {
+ fields = append(fields, channelmonitorhistory.FieldStatus)
+ }
+ if m.latency_ms != nil {
+ fields = append(fields, channelmonitorhistory.FieldLatencyMs)
+ }
+ if m.ping_latency_ms != nil {
+ fields = append(fields, channelmonitorhistory.FieldPingLatencyMs)
+ }
+ if m.message != nil {
+ fields = append(fields, channelmonitorhistory.FieldMessage)
+ }
+ if m.checked_at != nil {
+ fields = append(fields, channelmonitorhistory.FieldCheckedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitorhistory.FieldMonitorID:
+ return m.MonitorID()
+ case channelmonitorhistory.FieldModel:
+ return m.Model()
+ case channelmonitorhistory.FieldStatus:
+ return m.Status()
+ case channelmonitorhistory.FieldLatencyMs:
+ return m.LatencyMs()
+ case channelmonitorhistory.FieldPingLatencyMs:
+ return m.PingLatencyMs()
+ case channelmonitorhistory.FieldMessage:
+ return m.Message()
+ case channelmonitorhistory.FieldCheckedAt:
+ return m.CheckedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitorhistory.FieldMonitorID:
+ return m.OldMonitorID(ctx)
+ case channelmonitorhistory.FieldModel:
+ return m.OldModel(ctx)
+ case channelmonitorhistory.FieldStatus:
+ return m.OldStatus(ctx)
+ case channelmonitorhistory.FieldLatencyMs:
+ return m.OldLatencyMs(ctx)
+ case channelmonitorhistory.FieldPingLatencyMs:
+ return m.OldPingLatencyMs(ctx)
+ case channelmonitorhistory.FieldMessage:
+ return m.OldMessage(ctx)
+ case channelmonitorhistory.FieldCheckedAt:
+ return m.OldCheckedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitorHistory field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorHistoryMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitorhistory.FieldMonitorID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMonitorID(v)
+ return nil
+ case channelmonitorhistory.FieldModel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModel(v)
+ return nil
+ case channelmonitorhistory.FieldStatus:
+ v, ok := value.(channelmonitorhistory.Status)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStatus(v)
+ return nil
+ case channelmonitorhistory.FieldLatencyMs:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLatencyMs(v)
+ return nil
+ case channelmonitorhistory.FieldPingLatencyMs:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPingLatencyMs(v)
+ return nil
+ case channelmonitorhistory.FieldMessage:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMessage(v)
+ return nil
+ case channelmonitorhistory.FieldCheckedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCheckedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorHistory field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedFields() []string {
+ var fields []string
+ if m.addlatency_ms != nil {
+ fields = append(fields, channelmonitorhistory.FieldLatencyMs)
+ }
+ if m.addping_latency_ms != nil {
+ fields = append(fields, channelmonitorhistory.FieldPingLatencyMs)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorHistoryMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitorhistory.FieldLatencyMs:
+ return m.AddedLatencyMs()
+ case channelmonitorhistory.FieldPingLatencyMs:
+ return m.AddedPingLatencyMs()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorHistoryMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitorhistory.FieldLatencyMs:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddLatencyMs(v)
+ return nil
+ case channelmonitorhistory.FieldPingLatencyMs:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddPingLatencyMs(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorHistory numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorHistoryMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitorhistory.FieldLatencyMs) {
+ fields = append(fields, channelmonitorhistory.FieldLatencyMs)
+ }
+ if m.FieldCleared(channelmonitorhistory.FieldPingLatencyMs) {
+ fields = append(fields, channelmonitorhistory.FieldPingLatencyMs)
+ }
+ if m.FieldCleared(channelmonitorhistory.FieldMessage) {
+ fields = append(fields, channelmonitorhistory.FieldMessage)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorHistoryMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitorhistory.FieldLatencyMs:
+ m.ClearLatencyMs()
+ return nil
+ case channelmonitorhistory.FieldPingLatencyMs:
+ m.ClearPingLatencyMs()
+ return nil
+ case channelmonitorhistory.FieldMessage:
+ m.ClearMessage()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorHistory nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorHistoryMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitorhistory.FieldMonitorID:
+ m.ResetMonitorID()
+ return nil
+ case channelmonitorhistory.FieldModel:
+ m.ResetModel()
+ return nil
+ case channelmonitorhistory.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case channelmonitorhistory.FieldLatencyMs:
+ m.ResetLatencyMs()
+ return nil
+ case channelmonitorhistory.FieldPingLatencyMs:
+ m.ResetPingLatencyMs()
+ return nil
+ case channelmonitorhistory.FieldMessage:
+ m.ResetMessage()
+ return nil
+ case channelmonitorhistory.FieldCheckedAt:
+ m.ResetCheckedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorHistory field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitor != nil {
+ edges = append(edges, channelmonitorhistory.EdgeMonitor)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorHistoryMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorhistory.EdgeMonitor:
+ if id := m.monitor; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorHistoryMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorHistoryMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedmonitor {
+ edges = append(edges, channelmonitorhistory.EdgeMonitor)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorHistoryMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitorhistory.EdgeMonitor:
+ return m.clearedmonitor
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorHistoryMutation) ClearEdge(name string) error {
+ switch name {
+ case channelmonitorhistory.EdgeMonitor:
+ m.ClearMonitor()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorHistory unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorHistoryMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitorhistory.EdgeMonitor:
+ m.ResetMonitor()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorHistory edge %s", name)
+}
+
+// ChannelMonitorRequestTemplateMutation represents an operation that mutates the ChannelMonitorRequestTemplate nodes in the graph.
+type ChannelMonitorRequestTemplateMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ provider *channelmonitorrequesttemplate.Provider
+ description *string
+ extra_headers *map[string]string
+ body_override_mode *string
+ body_override *map[string]interface{}
+ clearedFields map[string]struct{}
+ monitors map[int64]struct{}
+ removedmonitors map[int64]struct{}
+ clearedmonitors bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorRequestTemplate, error)
+ predicates []predicate.ChannelMonitorRequestTemplate
+}
+
+var _ ent.Mutation = (*ChannelMonitorRequestTemplateMutation)(nil)
+
+// channelmonitorrequesttemplateOption allows management of the mutation configuration using functional options.
+type channelmonitorrequesttemplateOption func(*ChannelMonitorRequestTemplateMutation)
+
+// newChannelMonitorRequestTemplateMutation creates new mutation for the ChannelMonitorRequestTemplate entity.
+func newChannelMonitorRequestTemplateMutation(c config, op Op, opts ...channelmonitorrequesttemplateOption) *ChannelMonitorRequestTemplateMutation {
+ m := &ChannelMonitorRequestTemplateMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitorRequestTemplate,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorRequestTemplateID sets the ID field of the mutation.
+func withChannelMonitorRequestTemplateID(id int64) channelmonitorrequesttemplateOption {
+ return func(m *ChannelMonitorRequestTemplateMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitorRequestTemplate
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitorRequestTemplate.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitorRequestTemplate sets the old ChannelMonitorRequestTemplate of the mutation.
+func withChannelMonitorRequestTemplate(node *ChannelMonitorRequestTemplate) channelmonitorrequesttemplateOption {
+ return func(m *ChannelMonitorRequestTemplateMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorRequestTemplate, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorRequestTemplateMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorRequestTemplateMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorRequestTemplateMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorRequestTemplate.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetName sets the "name" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetName() {
+ m.name = nil
+}
+
+// SetProvider sets the "provider" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetProvider(c channelmonitorrequesttemplate.Provider) {
+ m.provider = &c
+}
+
+// Provider returns the value of the "provider" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Provider() (r channelmonitorrequesttemplate.Provider, exists bool) {
+ v := m.provider
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProvider returns the old "provider" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldProvider(ctx context.Context) (v channelmonitorrequesttemplate.Provider, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProvider is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProvider requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProvider: %w", err)
+ }
+ return oldValue.Provider, nil
+}
+
+// ResetProvider resets all changes to the "provider" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() {
+ m.provider = nil
+}
+
+// SetDescription sets the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) {
+ m.description = &s
+}
+
+// Description returns the value of the "description" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Description() (r string, exists bool) {
+ v := m.description
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDescription returns the old "description" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldDescription(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDescription requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ }
+ return oldValue.Description, nil
+}
+
+// ClearDescription clears the value of the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) ClearDescription() {
+ m.description = nil
+ m.clearedFields[channelmonitorrequesttemplate.FieldDescription] = struct{}{}
+}
+
+// DescriptionCleared returns if the "description" field was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) DescriptionCleared() bool {
+ _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldDescription]
+ return ok
+}
+
+// ResetDescription resets all changes to the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetDescription() {
+ m.description = nil
+ delete(m.clearedFields, channelmonitorrequesttemplate.FieldDescription)
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetExtraHeaders(value map[string]string) {
+ m.extra_headers = &value
+}
+
+// ExtraHeaders returns the value of the "extra_headers" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ExtraHeaders() (r map[string]string, exists bool) {
+ v := m.extra_headers
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraHeaders requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err)
+ }
+ return oldValue.ExtraHeaders, nil
+}
+
+// ResetExtraHeaders resets all changes to the "extra_headers" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetExtraHeaders() {
+ m.extra_headers = nil
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverrideMode(s string) {
+ m.body_override_mode = &s
+}
+
+// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideMode() (r string, exists bool) {
+ v := m.body_override_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err)
+ }
+ return oldValue.BodyOverrideMode, nil
+}
+
+// ResetBodyOverrideMode resets all changes to the "body_override_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverrideMode() {
+ m.body_override_mode = nil
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverride(value map[string]interface{}) {
+ m.body_override = &value
+}
+
+// BodyOverride returns the value of the "body_override" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverride() (r map[string]interface{}, exists bool) {
+ v := m.body_override
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverride requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err)
+ }
+ return oldValue.BodyOverride, nil
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) ClearBodyOverride() {
+ m.body_override = nil
+ m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride] = struct{}{}
+}
+
+// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideCleared() bool {
+ _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride]
+ return ok
+}
+
+// ResetBodyOverride resets all changes to the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverride() {
+ m.body_override = nil
+ delete(m.clearedFields, channelmonitorrequesttemplate.FieldBodyOverride)
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by ids.
+func (m *ChannelMonitorRequestTemplateMutation) AddMonitorIDs(ids ...int64) {
+ if m.monitors == nil {
+ m.monitors = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.monitors[ids[i]] = struct{}{}
+ }
+}
+
+// ClearMonitors clears the "monitors" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorRequestTemplateMutation) ClearMonitors() {
+ m.clearedmonitors = true
+}
+
+// MonitorsCleared reports if the "monitors" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsCleared() bool {
+ return m.clearedmonitors
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to the ChannelMonitor entity by IDs.
+func (m *ChannelMonitorRequestTemplateMutation) RemoveMonitorIDs(ids ...int64) {
+ if m.removedmonitors == nil {
+ m.removedmonitors = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.monitors, ids[i])
+ m.removedmonitors[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedMonitors returns the removed IDs of the "monitors" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedMonitorsIDs() (ids []int64) {
+ for id := range m.removedmonitors {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// MonitorsIDs returns the "monitors" edge IDs in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsIDs() (ids []int64) {
+ for id := range m.monitors {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetMonitors resets all changes to the "monitors" edge.
+func (m *ChannelMonitorRequestTemplateMutation) ResetMonitors() {
+ m.monitors = nil
+ m.clearedmonitors = false
+ m.removedmonitors = nil
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateMutation builder.
+func (m *ChannelMonitorRequestTemplateMutation) Where(ps ...predicate.ChannelMonitorRequestTemplate) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorRequestTemplateMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorRequestTemplateMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitorRequestTemplate, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorRequestTemplateMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorRequestTemplateMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorRequestTemplate).
+func (m *ChannelMonitorRequestTemplateMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
+ fields := make([]string, 0, 8)
+ if m.created_at != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldUpdatedAt)
+ }
+ if m.name != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldName)
+ }
+ if m.provider != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldProvider)
+ }
+ if m.description != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
+ }
+ if m.extra_headers != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldExtraHeaders)
+ }
+ if m.body_override_mode != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverrideMode)
+ }
+ if m.body_override != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ return m.CreatedAt()
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case channelmonitorrequesttemplate.FieldName:
+ return m.Name()
+ case channelmonitorrequesttemplate.FieldProvider:
+ return m.Provider()
+ case channelmonitorrequesttemplate.FieldDescription:
+ return m.Description()
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ return m.ExtraHeaders()
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ return m.BodyOverrideMode()
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ return m.BodyOverride()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case channelmonitorrequesttemplate.FieldName:
+ return m.OldName(ctx)
+ case channelmonitorrequesttemplate.FieldProvider:
+ return m.OldProvider(ctx)
+ case channelmonitorrequesttemplate.FieldDescription:
+ return m.OldDescription(ctx)
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ return m.OldExtraHeaders(ctx)
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ return m.OldBodyOverrideMode(ctx)
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ return m.OldBodyOverride(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldProvider:
+ v, ok := value.(channelmonitorrequesttemplate.Provider)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProvider(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldDescription:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDescription(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ v, ok := value.(map[string]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraHeaders(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverrideMode(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverride(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedFields() []string {
+ return nil
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) AddedField(name string) (ent.Value, bool) {
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorRequestTemplateMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitorrequesttemplate.FieldDescription) {
+ fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
+ }
+ if m.FieldCleared(channelmonitorrequesttemplate.FieldBodyOverride) {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldDescription:
+ m.ClearDescription()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ m.ClearBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case channelmonitorrequesttemplate.FieldName:
+ m.ResetName()
+ return nil
+ case channelmonitorrequesttemplate.FieldProvider:
+ m.ResetProvider()
+ return nil
+ case channelmonitorrequesttemplate.FieldDescription:
+ m.ResetDescription()
+ return nil
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ m.ResetExtraHeaders()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ m.ResetBodyOverrideMode()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ m.ResetBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitors != nil {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ ids := make([]ent.Value, 0, len(m.monitors))
+ for id := range m.monitors {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.removedmonitors != nil {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ ids := make([]ent.Value, 0, len(m.removedmonitors))
+ for id := range m.removedmonitors {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedmonitors {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ return m.clearedmonitors
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ m.ResetMonitors()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate edge %s", name)
+}
+
// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
type ErrorPassthroughRuleMutation struct {
config
@@ -8255,6 +14787,8 @@ type GroupMutation struct {
require_privacy_set *bool
default_mapped_model *string
messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
+ rpm_limit *int
+ addrpm_limit *int
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -9843,6 +16377,62 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
m.messages_dispatch_model_config = nil
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (m *GroupMutation) SetRpmLimit(i int) {
+ m.rpm_limit = &i
+ m.addrpm_limit = nil
+}
+
+// RpmLimit returns the value of the "rpm_limit" field in the mutation.
+func (m *GroupMutation) RpmLimit() (r int, exists bool) {
+ v := m.rpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRpmLimit requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
+ }
+ return oldValue.RpmLimit, nil
+}
+
+// AddRpmLimit adds i to the "rpm_limit" field.
+func (m *GroupMutation) AddRpmLimit(i int) {
+ if m.addrpm_limit != nil {
+ *m.addrpm_limit += i
+ } else {
+ m.addrpm_limit = &i
+ }
+}
+
+// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
+func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) {
+ v := m.addrpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRpmLimit resets all changes to the "rpm_limit" field.
+func (m *GroupMutation) ResetRpmLimit() {
+ m.rpm_limit = nil
+ m.addrpm_limit = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10201,7 +16791,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 30)
+ fields := make([]string, 0, 31)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10292,6 +16882,9 @@ func (m *GroupMutation) Fields() []string {
if m.messages_dispatch_model_config != nil {
fields = append(fields, group.FieldMessagesDispatchModelConfig)
}
+ if m.rpm_limit != nil {
+ fields = append(fields, group.FieldRpmLimit)
+ }
return fields
}
@@ -10360,6 +16953,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.DefaultMappedModel()
case group.FieldMessagesDispatchModelConfig:
return m.MessagesDispatchModelConfig()
+ case group.FieldRpmLimit:
+ return m.RpmLimit()
}
return nil, false
}
@@ -10429,6 +17024,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldDefaultMappedModel(ctx)
case group.FieldMessagesDispatchModelConfig:
return m.OldMessagesDispatchModelConfig(ctx)
+ case group.FieldRpmLimit:
+ return m.OldRpmLimit(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10648,6 +17245,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetMessagesDispatchModelConfig(v)
return nil
+ case group.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRpmLimit(v)
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -10689,6 +17293,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addsort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
+ if m.addrpm_limit != nil {
+ fields = append(fields, group.FieldRpmLimit)
+ }
return fields
}
@@ -10719,6 +17326,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedFallbackGroupIDOnInvalidRequest()
case group.FieldSortOrder:
return m.AddedSortOrder()
+ case group.FieldRpmLimit:
+ return m.AddedRpmLimit()
}
return nil, false
}
@@ -10805,6 +17414,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddSortOrder(v)
return nil
+ case group.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddRpmLimit(v)
+ return nil
}
return fmt.Errorf("unknown Group numeric field %s", name)
}
@@ -10991,6 +17607,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldMessagesDispatchModelConfig:
m.ResetMessagesDispatchModelConfig()
return nil
+ case group.FieldRpmLimit:
+ m.ResetRpmLimit()
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -12191,6 +18810,781 @@ func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
}
+// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph.
+type IdentityAdoptionDecisionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ adopt_display_name *bool
+ adopt_avatar *bool
+ decided_at *time.Time
+ clearedFields map[string]struct{}
+ pending_auth_session *int64
+ clearedpending_auth_session bool
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*IdentityAdoptionDecision, error)
+ predicates []predicate.IdentityAdoptionDecision
+}
+
+var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil)
+
+// identityadoptiondecisionOption allows management of the mutation configuration using functional options.
+type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation)
+
+// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity.
+func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation {
+ m := &IdentityAdoptionDecisionMutation{
+ config: c,
+ op: op,
+ typ: TypeIdentityAdoptionDecision,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withIdentityAdoptionDecisionID sets the ID field of the mutation.
+func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *IdentityAdoptionDecision
+ )
+ m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation.
+func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m IdentityAdoptionDecisionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) {
+ m.pending_auth_session = &i
+}
+
+// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) {
+ v := m.pending_auth_session
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err)
+ }
+ return oldValue.PendingAuthSessionID, nil
+}
+
+// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() {
+ m.pending_auth_session = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() {
+ m.identity = nil
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool {
+ _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID]
+ return ok
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() {
+ m.identity = nil
+ delete(m.clearedFields, identityadoptiondecision.FieldIdentityID)
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) {
+ m.adopt_display_name = &b
+}
+
+// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) {
+ v := m.adopt_display_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err)
+ }
+ return oldValue.AdoptDisplayName, nil
+}
+
+// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() {
+ m.adopt_display_name = nil
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) {
+ m.adopt_avatar = &b
+}
+
+// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) {
+ v := m.adopt_avatar
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptAvatar requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err)
+ }
+ return oldValue.AdoptAvatar, nil
+}
+
+// ResetAdoptAvatar resets all changes to the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() {
+ m.adopt_avatar = nil
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) {
+ m.decided_at = &t
+}
+
+// DecidedAt returns the value of the "decided_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) {
+ v := m.decided_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDecidedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err)
+ }
+ return oldValue.DecidedAt, nil
+}
+
+// ResetDecidedAt resets all changes to the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() {
+ m.decided_at = nil
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() {
+ m.clearedpending_auth_session = true
+ m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{}
+}
+
+// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool {
+ return m.clearedpending_auth_session
+}
+
+// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// PendingAuthSessionID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) {
+ if id := m.pending_auth_session; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() {
+ m.pending_auth_session = nil
+ m.clearedpending_auth_session = false
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool {
+ return m.IdentityIDCleared() || m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder.
+func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdentityAdoptionDecision, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *IdentityAdoptionDecisionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (IdentityAdoptionDecision).
+func (m *IdentityAdoptionDecisionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdentityAdoptionDecisionMutation) Fields() []string {
+ fields := make([]string, 0, 7)
+ if m.created_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldUpdatedAt)
+ }
+ if m.pending_auth_session != nil {
+ fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if m.identity != nil {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ if m.adopt_display_name != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName)
+ }
+ if m.adopt_avatar != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptAvatar)
+ }
+ if m.decided_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldDecidedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.CreatedAt()
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.PendingAuthSessionID()
+ case identityadoptiondecision.FieldIdentityID:
+ return m.IdentityID()
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.AdoptDisplayName()
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.AdoptAvatar()
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.DecidedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.OldPendingAuthSessionID(ctx)
+ case identityadoptiondecision.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.OldAdoptDisplayName(ctx)
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.OldAdoptAvatar(ctx)
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.OldDecidedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPendingAuthSessionID(v)
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptDisplayName(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptAvatar(v)
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDecidedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(identityadoptiondecision.FieldIdentityID) {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldIdentityID:
+ m.ClearIdentityID()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ m.ResetPendingAuthSessionID()
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ m.ResetAdoptDisplayName()
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ m.ResetAdoptAvatar()
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ m.ResetDecidedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.pending_auth_session != nil {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.identity != nil {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ if id := m.pending_auth_session; id != nil {
+ return []ent.Value{*id}
+ }
+ case identityadoptiondecision.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedpending_auth_session {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.clearedidentity {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ return m.clearedpending_auth_session
+ case identityadoptiondecision.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ClearPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ResetPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name)
+}
+
// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph.
type PaymentAuditLogMutation struct {
config
@@ -12763,6 +20157,8 @@ type PaymentOrderMutation struct {
subscription_days *int
addsubscription_days *int
provider_instance_id *string
+ provider_key *string
+ provider_snapshot *map[string]interface{}
status *string
refund_amount *float64
addrefund_amount *float64
@@ -13799,6 +21195,104 @@ func (m *PaymentOrderMutation) ResetProviderInstanceID() {
delete(m.clearedFields, paymentorder.FieldProviderInstanceID)
}
+// SetProviderKey sets the "provider_key" field.
+func (m *PaymentOrderMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (m *PaymentOrderMutation) ClearProviderKey() {
+ m.provider_key = nil
+ m.clearedFields[paymentorder.FieldProviderKey] = struct{}{}
+}
+
+// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderKeyCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderKey]
+ return ok
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PaymentOrderMutation) ResetProviderKey() {
+ m.provider_key = nil
+ delete(m.clearedFields, paymentorder.FieldProviderKey)
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) {
+ m.provider_snapshot = &value
+}
+
+// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation.
+func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) {
+ v := m.provider_snapshot
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSnapshot requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err)
+ }
+ return oldValue.ProviderSnapshot, nil
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ClearProviderSnapshot() {
+ m.provider_snapshot = nil
+ m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{}
+}
+
+// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot]
+ return ok
+}
+
+// ResetProviderSnapshot resets all changes to the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ResetProviderSnapshot() {
+ m.provider_snapshot = nil
+ delete(m.clearedFields, paymentorder.FieldProviderSnapshot)
+}
+
// SetStatus sets the "status" field.
func (m *PaymentOrderMutation) SetStatus(s string) {
m.status = &s
@@ -14658,7 +22152,7 @@ func (m *PaymentOrderMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *PaymentOrderMutation) Fields() []string {
- fields := make([]string, 0, 37)
+ fields := make([]string, 0, 39)
if m.user != nil {
fields = append(fields, paymentorder.FieldUserID)
}
@@ -14716,6 +22210,12 @@ func (m *PaymentOrderMutation) Fields() []string {
if m.provider_instance_id != nil {
fields = append(fields, paymentorder.FieldProviderInstanceID)
}
+ if m.provider_key != nil {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
+ if m.provider_snapshot != nil {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.status != nil {
fields = append(fields, paymentorder.FieldStatus)
}
@@ -14816,6 +22316,10 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
return m.SubscriptionDays()
case paymentorder.FieldProviderInstanceID:
return m.ProviderInstanceID()
+ case paymentorder.FieldProviderKey:
+ return m.ProviderKey()
+ case paymentorder.FieldProviderSnapshot:
+ return m.ProviderSnapshot()
case paymentorder.FieldStatus:
return m.Status()
case paymentorder.FieldRefundAmount:
@@ -14899,6 +22403,10 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldSubscriptionDays(ctx)
case paymentorder.FieldProviderInstanceID:
return m.OldProviderInstanceID(ctx)
+ case paymentorder.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case paymentorder.FieldProviderSnapshot:
+ return m.OldProviderSnapshot(ctx)
case paymentorder.FieldStatus:
return m.OldStatus(ctx)
case paymentorder.FieldRefundAmount:
@@ -15077,6 +22585,20 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
}
m.SetProviderInstanceID(v)
return nil
+ case paymentorder.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSnapshot(v)
+ return nil
case paymentorder.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -15344,6 +22866,12 @@ func (m *PaymentOrderMutation) ClearedFields() []string {
if m.FieldCleared(paymentorder.FieldProviderInstanceID) {
fields = append(fields, paymentorder.FieldProviderInstanceID)
}
+ if m.FieldCleared(paymentorder.FieldProviderKey) {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
+ if m.FieldCleared(paymentorder.FieldProviderSnapshot) {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
if m.FieldCleared(paymentorder.FieldRefundReason) {
fields = append(fields, paymentorder.FieldRefundReason)
}
@@ -15412,6 +22940,12 @@ func (m *PaymentOrderMutation) ClearField(name string) error {
case paymentorder.FieldProviderInstanceID:
m.ClearProviderInstanceID()
return nil
+ case paymentorder.FieldProviderKey:
+ m.ClearProviderKey()
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ClearProviderSnapshot()
+ return nil
case paymentorder.FieldRefundReason:
m.ClearRefundReason()
return nil
@@ -15507,6 +23041,12 @@ func (m *PaymentOrderMutation) ResetField(name string) error {
case paymentorder.FieldProviderInstanceID:
m.ResetProviderInstanceID()
return nil
+ case paymentorder.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ResetProviderSnapshot()
+ return nil
case paymentorder.FieldStatus:
m.ResetStatus()
return nil
@@ -16595,6 +24135,1645 @@ func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown PaymentProviderInstance edge %s", name)
}
+// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph.
+type PendingAuthSessionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ session_token *string
+ intent *string
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ redirect_to *string
+ resolved_email *string
+ registration_password_hash *string
+ upstream_identity_claims *map[string]interface{}
+ local_flow_state *map[string]interface{}
+ browser_session_key *string
+ completion_code_hash *string
+ completion_code_expires_at *time.Time
+ email_verified_at *time.Time
+ password_verified_at *time.Time
+ totp_verified_at *time.Time
+ expires_at *time.Time
+ consumed_at *time.Time
+ clearedFields map[string]struct{}
+ target_user *int64
+ clearedtarget_user bool
+ adoption_decision *int64
+ clearedadoption_decision bool
+ done bool
+ oldValue func(context.Context) (*PendingAuthSession, error)
+ predicates []predicate.PendingAuthSession
+}
+
+var _ ent.Mutation = (*PendingAuthSessionMutation)(nil)
+
+// pendingauthsessionOption allows management of the mutation configuration using functional options.
+type pendingauthsessionOption func(*PendingAuthSessionMutation)
+
+// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity.
+func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation {
+ m := &PendingAuthSessionMutation{
+ config: c,
+ op: op,
+ typ: TypePendingAuthSession,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPendingAuthSessionID sets the ID field of the mutation.
+func withPendingAuthSessionID(id int64) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PendingAuthSession
+ )
+ m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PendingAuthSession.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPendingAuthSession sets the old PendingAuthSession of the mutation.
+func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ m.oldValue = func(context.Context) (*PendingAuthSession, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PendingAuthSessionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PendingAuthSessionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PendingAuthSessionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PendingAuthSessionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetSessionToken sets the "session_token" field.
+func (m *PendingAuthSessionMutation) SetSessionToken(s string) {
+ m.session_token = &s
+}
+
+// SessionToken returns the value of the "session_token" field in the mutation.
+func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) {
+ v := m.session_token
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSessionToken is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSessionToken requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSessionToken: %w", err)
+ }
+ return oldValue.SessionToken, nil
+}
+
+// ResetSessionToken resets all changes to the "session_token" field.
+func (m *PendingAuthSessionMutation) ResetSessionToken() {
+ m.session_token = nil
+}
+
+// SetIntent sets the "intent" field.
+func (m *PendingAuthSessionMutation) SetIntent(s string) {
+ m.intent = &s
+}
+
+// Intent returns the value of the "intent" field in the mutation.
+func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) {
+ v := m.intent
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntent returns the old "intent" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntent: %w", err)
+ }
+ return oldValue.Intent, nil
+}
+
+// ResetIntent resets all changes to the "intent" field.
+func (m *PendingAuthSessionMutation) ResetIntent() {
+ m.intent = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *PendingAuthSessionMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *PendingAuthSessionMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PendingAuthSessionMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PendingAuthSessionMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *PendingAuthSessionMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *PendingAuthSessionMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) {
+ m.target_user = &i
+}
+
+// TargetUserID returns the value of the "target_user_id" field in the mutation.
+func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) {
+ v := m.target_user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTargetUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err)
+ }
+ return oldValue.TargetUserID, nil
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ClearTargetUserID() {
+ m.target_user = nil
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID]
+ return ok
+}
+
+// ResetTargetUserID resets all changes to the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ResetTargetUserID() {
+ m.target_user = nil
+ delete(m.clearedFields, pendingauthsession.FieldTargetUserID)
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (m *PendingAuthSessionMutation) SetRedirectTo(s string) {
+ m.redirect_to = &s
+}
+
+// RedirectTo returns the value of the "redirect_to" field in the mutation.
+func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) {
+ v := m.redirect_to
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRedirectTo requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err)
+ }
+ return oldValue.RedirectTo, nil
+}
+
+// ResetRedirectTo resets all changes to the "redirect_to" field.
+func (m *PendingAuthSessionMutation) ResetRedirectTo() {
+ m.redirect_to = nil
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) {
+ m.resolved_email = &s
+}
+
+// ResolvedEmail returns the value of the "resolved_email" field in the mutation.
+func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) {
+ v := m.resolved_email
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResolvedEmail requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err)
+ }
+ return oldValue.ResolvedEmail, nil
+}
+
+// ResetResolvedEmail resets all changes to the "resolved_email" field.
+func (m *PendingAuthSessionMutation) ResetResolvedEmail() {
+ m.resolved_email = nil
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) {
+ m.registration_password_hash = &s
+}
+
+// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) {
+ v := m.registration_password_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err)
+ }
+ return oldValue.RegistrationPasswordHash, nil
+}
+
+// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() {
+ m.registration_password_hash = nil
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) {
+ m.upstream_identity_claims = &value
+}
+
+// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation.
+func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) {
+ v := m.upstream_identity_claims
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err)
+ }
+ return oldValue.UpstreamIdentityClaims, nil
+}
+
+// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() {
+ m.upstream_identity_claims = nil
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) {
+ m.local_flow_state = &value
+}
+
+// LocalFlowState returns the value of the "local_flow_state" field in the mutation.
+func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) {
+ v := m.local_flow_state
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLocalFlowState requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err)
+ }
+ return oldValue.LocalFlowState, nil
+}
+
+// ResetLocalFlowState resets all changes to the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) ResetLocalFlowState() {
+ m.local_flow_state = nil
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) {
+ m.browser_session_key = &s
+}
+
+// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation.
+func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) {
+ v := m.browser_session_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err)
+ }
+ return oldValue.BrowserSessionKey, nil
+}
+
+// ResetBrowserSessionKey resets all changes to the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() {
+ m.browser_session_key = nil
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) {
+ m.completion_code_hash = &s
+}
+
+// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) {
+ v := m.completion_code_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err)
+ }
+ return oldValue.CompletionCodeHash, nil
+}
+
+// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() {
+ m.completion_code_hash = nil
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) {
+ m.completion_code_expires_at = &t
+}
+
+// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) {
+ v := m.completion_code_expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err)
+ }
+ return oldValue.CompletionCodeExpiresAt, nil
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{}
+}
+
+// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt]
+ return ok
+}
+
+// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt)
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) {
+ m.email_verified_at = &t
+}
+
+// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) {
+ v := m.email_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err)
+ }
+ return oldValue.EmailVerifiedAt, nil
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() {
+ m.email_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{}
+}
+
+// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt]
+ return ok
+}
+
+// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() {
+ m.email_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt)
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) {
+ m.password_verified_at = &t
+}
+
+// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) {
+ v := m.password_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err)
+ }
+ return oldValue.PasswordVerifiedAt, nil
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{}
+}
+
+// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt]
+ return ok
+}
+
+// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt)
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) {
+ m.totp_verified_at = &t
+}
+
+// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) {
+ v := m.totp_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err)
+ }
+ return oldValue.TotpVerifiedAt, nil
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{}
+}
+
+// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt]
+ return ok
+}
+
+// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PendingAuthSessionMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) {
+ m.consumed_at = &t
+}
+
+// ConsumedAt returns the value of the "consumed_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) {
+ v := m.consumed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldConsumedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err)
+ }
+ return oldValue.ConsumedAt, nil
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ClearConsumedAt() {
+ m.consumed_at = nil
+ m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{}
+}
+
+// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt]
+ return ok
+}
+
+// ResetConsumedAt resets all changes to the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ResetConsumedAt() {
+ m.consumed_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldConsumedAt)
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (m *PendingAuthSessionMutation) ClearTargetUser() {
+ m.clearedtarget_user = true
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserCleared reports if the "target_user" edge to the User entity was cleared.
+func (m *PendingAuthSessionMutation) TargetUserCleared() bool {
+ return m.TargetUserIDCleared() || m.clearedtarget_user
+}
+
+// TargetUserIDs returns the "target_user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// TargetUserID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) {
+ if id := m.target_user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetTargetUser resets all changes to the "target_user" edge.
+func (m *PendingAuthSessionMutation) ResetTargetUser() {
+ m.target_user = nil
+ m.clearedtarget_user = false
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id.
+func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) {
+ m.adoption_decision = &id
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (m *PendingAuthSessionMutation) ClearAdoptionDecision() {
+ m.clearedadoption_decision = true
+}
+
+// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool {
+ return m.clearedadoption_decision
+}
+
+// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation.
+func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) {
+ if m.adoption_decision != nil {
+ return *m.adoption_decision, true
+ }
+ return
+}
+
+// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// AdoptionDecisionID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) {
+ if id := m.adoption_decision; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetAdoptionDecision resets all changes to the "adoption_decision" edge.
+func (m *PendingAuthSessionMutation) ResetAdoptionDecision() {
+ m.adoption_decision = nil
+ m.clearedadoption_decision = false
+}
+
+// Where appends a list predicates to the PendingAuthSessionMutation builder.
+func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PendingAuthSession, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PendingAuthSessionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PendingAuthSessionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PendingAuthSession).
+func (m *PendingAuthSessionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PendingAuthSessionMutation) Fields() []string {
+ fields := make([]string, 0, 21)
+ if m.created_at != nil {
+ fields = append(fields, pendingauthsession.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, pendingauthsession.FieldUpdatedAt)
+ }
+ if m.session_token != nil {
+ fields = append(fields, pendingauthsession.FieldSessionToken)
+ }
+ if m.intent != nil {
+ fields = append(fields, pendingauthsession.FieldIntent)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, pendingauthsession.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, pendingauthsession.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, pendingauthsession.FieldProviderSubject)
+ }
+ if m.target_user != nil {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.redirect_to != nil {
+ fields = append(fields, pendingauthsession.FieldRedirectTo)
+ }
+ if m.resolved_email != nil {
+ fields = append(fields, pendingauthsession.FieldResolvedEmail)
+ }
+ if m.registration_password_hash != nil {
+ fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash)
+ }
+ if m.upstream_identity_claims != nil {
+ fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims)
+ }
+ if m.local_flow_state != nil {
+ fields = append(fields, pendingauthsession.FieldLocalFlowState)
+ }
+ if m.browser_session_key != nil {
+ fields = append(fields, pendingauthsession.FieldBrowserSessionKey)
+ }
+ if m.completion_code_hash != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeHash)
+ }
+ if m.completion_code_expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.email_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.password_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.totp_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldExpiresAt)
+ }
+ if m.consumed_at != nil {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.CreatedAt()
+ case pendingauthsession.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case pendingauthsession.FieldSessionToken:
+ return m.SessionToken()
+ case pendingauthsession.FieldIntent:
+ return m.Intent()
+ case pendingauthsession.FieldProviderType:
+ return m.ProviderType()
+ case pendingauthsession.FieldProviderKey:
+ return m.ProviderKey()
+ case pendingauthsession.FieldProviderSubject:
+ return m.ProviderSubject()
+ case pendingauthsession.FieldTargetUserID:
+ return m.TargetUserID()
+ case pendingauthsession.FieldRedirectTo:
+ return m.RedirectTo()
+ case pendingauthsession.FieldResolvedEmail:
+ return m.ResolvedEmail()
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.RegistrationPasswordHash()
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.UpstreamIdentityClaims()
+ case pendingauthsession.FieldLocalFlowState:
+ return m.LocalFlowState()
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.BrowserSessionKey()
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.CompletionCodeHash()
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.CompletionCodeExpiresAt()
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.EmailVerifiedAt()
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.PasswordVerifiedAt()
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.TotpVerifiedAt()
+ case pendingauthsession.FieldExpiresAt:
+ return m.ExpiresAt()
+ case pendingauthsession.FieldConsumedAt:
+ return m.ConsumedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case pendingauthsession.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case pendingauthsession.FieldSessionToken:
+ return m.OldSessionToken(ctx)
+ case pendingauthsession.FieldIntent:
+ return m.OldIntent(ctx)
+ case pendingauthsession.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case pendingauthsession.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case pendingauthsession.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case pendingauthsession.FieldTargetUserID:
+ return m.OldTargetUserID(ctx)
+ case pendingauthsession.FieldRedirectTo:
+ return m.OldRedirectTo(ctx)
+ case pendingauthsession.FieldResolvedEmail:
+ return m.OldResolvedEmail(ctx)
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.OldRegistrationPasswordHash(ctx)
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.OldUpstreamIdentityClaims(ctx)
+ case pendingauthsession.FieldLocalFlowState:
+ return m.OldLocalFlowState(ctx)
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.OldBrowserSessionKey(ctx)
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.OldCompletionCodeHash(ctx)
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.OldCompletionCodeExpiresAt(ctx)
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.OldEmailVerifiedAt(ctx)
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.OldPasswordVerifiedAt(ctx)
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.OldTotpVerifiedAt(ctx)
+ case pendingauthsession.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case pendingauthsession.FieldConsumedAt:
+ return m.OldConsumedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSessionToken(v)
+ return nil
+ case pendingauthsession.FieldIntent:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIntent(v)
+ return nil
+ case pendingauthsession.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTargetUserID(v)
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRedirectTo(v)
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResolvedEmail(v)
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRegistrationPasswordHash(v)
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpstreamIdentityClaims(v)
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLocalFlowState(v)
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBrowserSessionKey(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeHash(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEmailVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPasswordVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConsumedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PendingAuthSessionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown PendingAuthSession numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PendingAuthSessionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(pendingauthsession.FieldTargetUserID) {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldConsumedAt) {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PendingAuthSessionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearField(name string) error {
+ switch name {
+ case pendingauthsession.FieldTargetUserID:
+ m.ClearTargetUserID()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ClearCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ClearEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ClearPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ClearTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ClearConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetField(name string) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ m.ResetSessionToken()
+ return nil
+ case pendingauthsession.FieldIntent:
+ m.ResetIntent()
+ return nil
+ case pendingauthsession.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ m.ResetTargetUserID()
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ m.ResetRedirectTo()
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ m.ResetResolvedEmail()
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ m.ResetRegistrationPasswordHash()
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ m.ResetUpstreamIdentityClaims()
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ m.ResetLocalFlowState()
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ m.ResetBrowserSessionKey()
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ m.ResetCompletionCodeHash()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ResetCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ResetEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ResetPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ResetTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ResetConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PendingAuthSessionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.target_user != nil {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.adoption_decision != nil {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ if id := m.target_user; id != nil {
+ return []ent.Value{*id}
+ }
+ case pendingauthsession.EdgeAdoptionDecision:
+ if id := m.adoption_decision; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PendingAuthSessionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PendingAuthSessionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedtarget_user {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.clearedadoption_decision {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ return m.clearedtarget_user
+ case pendingauthsession.EdgeAdoptionDecision:
+ return m.clearedadoption_decision
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ClearTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ClearAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ResetTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ResetAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession edge %s", name)
+}
+
// PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph.
type PromoCodeMutation struct {
config
@@ -28264,6 +37443,9 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ signup_source *string
+ last_login_at *time.Time
+ last_active_at *time.Time
balance_notify_enabled *bool
balance_notify_threshold_type *string
balance_notify_threshold *float64
@@ -28271,6 +37453,8 @@ type UserMutation struct {
balance_notify_extra_emails *string
total_recharged *float64
addtotal_recharged *float64
+ rpm_limit *int
+ addrpm_limit *int
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -28302,6 +37486,12 @@ type UserMutation struct {
payment_orders map[int64]struct{}
removedpayment_orders map[int64]struct{}
clearedpayment_orders bool
+ auth_identities map[int64]struct{}
+ removedauth_identities map[int64]struct{}
+ clearedauth_identities bool
+ pending_auth_sessions map[int64]struct{}
+ removedpending_auth_sessions map[int64]struct{}
+ clearedpending_auth_sessions bool
done bool
oldValue func(context.Context) (*User, error)
predicates []predicate.User
@@ -28988,6 +38178,140 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetSignupSource sets the "signup_source" field.
+func (m *UserMutation) SetSignupSource(s string) {
+ m.signup_source = &s
+}
+
+// SignupSource returns the value of the "signup_source" field in the mutation.
+func (m *UserMutation) SignupSource() (r string, exists bool) {
+ v := m.signup_source
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSignupSource returns the old "signup_source" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSignupSource is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSignupSource requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSignupSource: %w", err)
+ }
+ return oldValue.SignupSource, nil
+}
+
+// ResetSignupSource resets all changes to the "signup_source" field.
+func (m *UserMutation) ResetSignupSource() {
+ m.signup_source = nil
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (m *UserMutation) SetLastLoginAt(t time.Time) {
+ m.last_login_at = &t
+}
+
+// LastLoginAt returns the value of the "last_login_at" field in the mutation.
+func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) {
+ v := m.last_login_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastLoginAt returns the old "last_login_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastLoginAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err)
+ }
+ return oldValue.LastLoginAt, nil
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (m *UserMutation) ClearLastLoginAt() {
+ m.last_login_at = nil
+ m.clearedFields[user.FieldLastLoginAt] = struct{}{}
+}
+
+// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation.
+func (m *UserMutation) LastLoginAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastLoginAt]
+ return ok
+}
+
+// ResetLastLoginAt resets all changes to the "last_login_at" field.
+func (m *UserMutation) ResetLastLoginAt() {
+ m.last_login_at = nil
+ delete(m.clearedFields, user.FieldLastLoginAt)
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (m *UserMutation) SetLastActiveAt(t time.Time) {
+ m.last_active_at = &t
+}
+
+// LastActiveAt returns the value of the "last_active_at" field in the mutation.
+func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) {
+ v := m.last_active_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastActiveAt returns the old "last_active_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastActiveAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err)
+ }
+ return oldValue.LastActiveAt, nil
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (m *UserMutation) ClearLastActiveAt() {
+ m.last_active_at = nil
+ m.clearedFields[user.FieldLastActiveAt] = struct{}{}
+}
+
+// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation.
+func (m *UserMutation) LastActiveAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastActiveAt]
+ return ok
+}
+
+// ResetLastActiveAt resets all changes to the "last_active_at" field.
+func (m *UserMutation) ResetLastActiveAt() {
+ m.last_active_at = nil
+ delete(m.clearedFields, user.FieldLastActiveAt)
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
m.balance_notify_enabled = &b
@@ -29222,6 +38546,62 @@ func (m *UserMutation) ResetTotalRecharged() {
m.addtotal_recharged = nil
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (m *UserMutation) SetRpmLimit(i int) {
+ m.rpm_limit = &i
+ m.addrpm_limit = nil
+}
+
+// RpmLimit returns the value of the "rpm_limit" field in the mutation.
+func (m *UserMutation) RpmLimit() (r int, exists bool) {
+ v := m.rpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRpmLimit returns the old "rpm_limit" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRpmLimit requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
+ }
+ return oldValue.RpmLimit, nil
+}
+
+// AddRpmLimit adds i to the "rpm_limit" field.
+func (m *UserMutation) AddRpmLimit(i int) {
+ if m.addrpm_limit != nil {
+ *m.addrpm_limit += i
+ } else {
+ m.addrpm_limit = &i
+ }
+}
+
+// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
+func (m *UserMutation) AddedRpmLimit() (r int, exists bool) {
+ v := m.addrpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRpmLimit resets all changes to the "rpm_limit" field.
+func (m *UserMutation) ResetRpmLimit() {
+ m.rpm_limit = nil
+ m.addrpm_limit = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -29762,6 +39142,114 @@ func (m *UserMutation) ResetPaymentOrders() {
m.removedpayment_orders = nil
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids.
+func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) {
+ if m.auth_identities == nil {
+ m.auth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.auth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) ClearAuthIdentities() {
+ m.clearedauth_identities = true
+}
+
+// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared.
+func (m *UserMutation) AuthIdentitiesCleared() bool {
+ return m.clearedauth_identities
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) {
+ if m.removedauth_identities == nil {
+ m.removedauth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.auth_identities, ids[i])
+ m.removedauth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) {
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation.
+func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) {
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAuthIdentities resets all changes to the "auth_identities" edge.
+func (m *UserMutation) ResetAuthIdentities() {
+ m.auth_identities = nil
+ m.clearedauth_identities = false
+ m.removedauth_identities = nil
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids.
+func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) {
+ if m.pending_auth_sessions == nil {
+ m.pending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.pending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) ClearPendingAuthSessions() {
+ m.clearedpending_auth_sessions = true
+}
+
+// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared.
+func (m *UserMutation) PendingAuthSessionsCleared() bool {
+ return m.clearedpending_auth_sessions
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) {
+ if m.removedpending_auth_sessions == nil {
+ m.removedpending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.pending_auth_sessions, ids[i])
+ m.removedpending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation.
+func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge.
+func (m *UserMutation) ResetPendingAuthSessions() {
+ m.pending_auth_sessions = nil
+ m.clearedpending_auth_sessions = false
+ m.removedpending_auth_sessions = nil
+}
+
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -29796,7 +39284,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 19)
+ fields := make([]string, 0, 23)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -29839,6 +39327,15 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.signup_source != nil {
+ fields = append(fields, user.FieldSignupSource)
+ }
+ if m.last_login_at != nil {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.last_active_at != nil {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
@@ -29854,6 +39351,9 @@ func (m *UserMutation) Fields() []string {
if m.total_recharged != nil {
fields = append(fields, user.FieldTotalRecharged)
}
+ if m.rpm_limit != nil {
+ fields = append(fields, user.FieldRpmLimit)
+ }
return fields
}
@@ -29890,6 +39390,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldSignupSource:
+ return m.SignupSource()
+ case user.FieldLastLoginAt:
+ return m.LastLoginAt()
+ case user.FieldLastActiveAt:
+ return m.LastActiveAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
case user.FieldBalanceNotifyThresholdType:
@@ -29900,6 +39406,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.BalanceNotifyExtraEmails()
case user.FieldTotalRecharged:
return m.TotalRecharged()
+ case user.FieldRpmLimit:
+ return m.RpmLimit()
}
return nil, false
}
@@ -29937,6 +39445,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldSignupSource:
+ return m.OldSignupSource(ctx)
+ case user.FieldLastLoginAt:
+ return m.OldLastLoginAt(ctx)
+ case user.FieldLastActiveAt:
+ return m.OldLastActiveAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
case user.FieldBalanceNotifyThresholdType:
@@ -29947,6 +39461,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldBalanceNotifyExtraEmails(ctx)
case user.FieldTotalRecharged:
return m.OldTotalRecharged(ctx)
+ case user.FieldRpmLimit:
+ return m.OldRpmLimit(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -30054,6 +39570,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldSignupSource:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSignupSource(v)
+ return nil
+ case user.FieldLastLoginAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastLoginAt(v)
+ return nil
+ case user.FieldLastActiveAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastActiveAt(v)
+ return nil
case user.FieldBalanceNotifyEnabled:
v, ok := value.(bool)
if !ok {
@@ -30089,6 +39626,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotalRecharged(v)
return nil
+ case user.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRpmLimit(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -30109,6 +39653,9 @@ func (m *UserMutation) AddedFields() []string {
if m.addtotal_recharged != nil {
fields = append(fields, user.FieldTotalRecharged)
}
+ if m.addrpm_limit != nil {
+ fields = append(fields, user.FieldRpmLimit)
+ }
return fields
}
@@ -30125,6 +39672,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalanceNotifyThreshold()
case user.FieldTotalRecharged:
return m.AddedTotalRecharged()
+ case user.FieldRpmLimit:
+ return m.AddedRpmLimit()
}
return nil, false
}
@@ -30162,6 +39711,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddTotalRecharged(v)
return nil
+ case user.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddRpmLimit(v)
+ return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -30179,6 +39735,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldLastLoginAt) {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.FieldCleared(user.FieldLastActiveAt) {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
@@ -30205,6 +39767,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldLastLoginAt:
+ m.ClearLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ClearLastActiveAt()
+ return nil
case user.FieldBalanceNotifyThreshold:
m.ClearBalanceNotifyThreshold()
return nil
@@ -30258,6 +39826,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldSignupSource:
+ m.ResetSignupSource()
+ return nil
+ case user.FieldLastLoginAt:
+ m.ResetLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ResetLastActiveAt()
+ return nil
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
@@ -30273,13 +39850,16 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotalRecharged:
m.ResetTotalRecharged()
return nil
+ case user.FieldRpmLimit:
+ m.ResetRpmLimit()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.api_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30310,6 +39890,12 @@ func (m *UserMutation) AddedEdges() []string {
if m.payment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.auth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.pending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30377,13 +39963,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.auth_identities))
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.pending_auth_sessions))
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.removedapi_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30414,6 +40012,12 @@ func (m *UserMutation) RemovedEdges() []string {
if m.removedpayment_orders != nil {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.removedauth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.removedpending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30481,13 +40085,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.removedauth_identities))
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions))
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
- edges := make([]string, 0, 10)
+ edges := make([]string, 0, 12)
if m.clearedapi_keys {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -30518,6 +40134,12 @@ func (m *UserMutation) ClearedEdges() []string {
if m.clearedpayment_orders {
edges = append(edges, user.EdgePaymentOrders)
}
+ if m.clearedauth_identities {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.clearedpending_auth_sessions {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -30545,6 +40167,10 @@ func (m *UserMutation) EdgeCleared(name string) bool {
return m.clearedpromo_code_usages
case user.EdgePaymentOrders:
return m.clearedpayment_orders
+ case user.EdgeAuthIdentities:
+ return m.clearedauth_identities
+ case user.EdgePendingAuthSessions:
+ return m.clearedpending_auth_sessions
}
return false
}
@@ -30591,6 +40217,12 @@ func (m *UserMutation) ResetEdge(name string) error {
case user.EdgePaymentOrders:
m.ResetPaymentOrders()
return nil
+ case user.EdgeAuthIdentities:
+ m.ResetAuthIdentities()
+ return nil
+ case user.EdgePendingAuthSessions:
+ m.ResetPendingAuthSessions()
+ return nil
}
return fmt.Errorf("unknown User edge %s", name)
}
diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go
index 6ea3e709..b131b8c8 100644
--- a/backend/ent/paymentorder.go
+++ b/backend/ent/paymentorder.go
@@ -3,6 +3,7 @@
package ent
import (
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -56,6 +57,10 @@ type PaymentOrder struct {
SubscriptionDays *int `json:"subscription_days,omitempty"`
// ProviderInstanceID holds the value of the "provider_instance_id" field.
ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey *string `json:"provider_key,omitempty"`
+ // ProviderSnapshot holds the value of the "provider_snapshot" field.
+ ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// RefundAmount holds the value of the "refund_amount" field.
@@ -123,13 +128,15 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case paymentorder.FieldProviderSnapshot:
+ values[i] = new([]byte)
case paymentorder.FieldForceRefund:
values[i] = new(sql.NullBool)
case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount:
values[i] = new(sql.NullFloat64)
case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays:
values[i] = new(sql.NullInt64)
- case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
+ case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
values[i] = new(sql.NullString)
case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@@ -276,6 +283,21 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error {
_m.ProviderInstanceID = new(string)
*_m.ProviderInstanceID = value.String
}
+ case paymentorder.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = new(string)
+ *_m.ProviderKey = value.String
+ }
+ case paymentorder.FieldProviderSnapshot:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil {
+ return fmt.Errorf("unmarshal field provider_snapshot: %w", err)
+ }
+ }
case paymentorder.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -508,6 +530,14 @@ func (_m *PaymentOrder) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ if v := _m.ProviderKey; v != nil {
+ builder.WriteString("provider_key=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("provider_snapshot=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot))
+ builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go
index 4467b2b6..62883794 100644
--- a/backend/ent/paymentorder/paymentorder.go
+++ b/backend/ent/paymentorder/paymentorder.go
@@ -52,6 +52,10 @@ const (
FieldSubscriptionDays = "subscription_days"
// FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database.
FieldProviderInstanceID = "provider_instance_id"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database.
+ FieldProviderSnapshot = "provider_snapshot"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldRefundAmount holds the string denoting the refund_amount field in the database.
@@ -123,6 +127,8 @@ var Columns = []string{
FieldSubscriptionGroupID,
FieldSubscriptionDays,
FieldProviderInstanceID,
+ FieldProviderKey,
+ FieldProviderSnapshot,
FieldStatus,
FieldRefundAmount,
FieldRefundReason,
@@ -176,6 +182,8 @@ var (
OrderTypeValidator func(string) error
// ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
ProviderInstanceIDValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -301,6 +309,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc()
}
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go
index 78520fac..e96bf51e 100644
--- a/backend/ent/paymentorder/where.go
+++ b/backend/ent/paymentorder/where.go
@@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v))
}
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
@@ -1360,6 +1365,91 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v))
}
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field.
+func ProviderKeyIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey))
+}
+
+// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field.
+func ProviderKeyNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot))
+}
+
+// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot))
+}
+
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.PaymentOrder {
return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go
index 03098339..3ee24f8e 100644
--- a/backend/ent/paymentorder_create.go
+++ b/backend/ent/paymentorder_create.go
@@ -225,6 +225,26 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO
return _c
}
+// SetProviderKey sets the "provider_key" field.
+func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetProviderKey(*v)
+ }
+ return _c
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate {
+ _c.mutation.SetProviderSnapshot(v)
+ return _c
+}
+
// SetStatus sets the "status" field.
func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate {
_c.mutation.SetStatus(v)
@@ -602,6 +622,11 @@ func (_c *PaymentOrderCreate) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)}
}
@@ -748,6 +773,14 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec)
_spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value)
_node.ProviderInstanceID = &value
}
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = &value
+ }
+ if value, ok := _c.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ _node.ProviderSnapshot = value
+ }
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -1201,6 +1234,42 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert {
return u
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderKey)
+ return u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderKey)
+ return u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderSnapshot, v)
+ return u
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert {
u.Set(paymentorder.FieldStatus, v)
@@ -1880,6 +1949,48 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne
})
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne {
return u.Update(func(s *PaymentOrderUpsert) {
@@ -2770,6 +2881,48 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu
})
}
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk {
return u.Update(func(s *PaymentOrderUpsert) {
diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go
index 5978fc29..378e0dad 100644
--- a/backend/ent/paymentorder_update.go
+++ b/backend/ent/paymentorder_update.go
@@ -385,6 +385,38 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate {
return _u
}
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate {
_u.mutation.SetStatus(v)
@@ -776,6 +808,11 @@ func (_u *PaymentOrderUpdate) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Status(); ok {
if err := paymentorder.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
@@ -910,6 +947,18 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error
if _u.mutation.ProviderInstanceIDCleared() {
_spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
@@ -1399,6 +1448,38 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn
return _u
}
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne {
_u.mutation.SetStatus(v)
@@ -1803,6 +1884,11 @@ func (_u *PaymentOrderUpdateOne) check() error {
return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
}
}
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Status(); ok {
if err := paymentorder.StatusValidator(v); err != nil {
return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
@@ -1954,6 +2040,18 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd
if _u.mutation.ProviderInstanceIDCleared() {
_spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
}
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
}
diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go
new file mode 100644
index 00000000..e77c065f
--- /dev/null
+++ b/backend/ent/pendingauthsession.go
@@ -0,0 +1,399 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSession is the model entity for the PendingAuthSession schema.
+type PendingAuthSession struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // SessionToken holds the value of the "session_token" field.
+ SessionToken string `json:"session_token,omitempty"`
+ // Intent holds the value of the "intent" field.
+ Intent string `json:"intent,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // TargetUserID holds the value of the "target_user_id" field.
+ TargetUserID *int64 `json:"target_user_id,omitempty"`
+ // RedirectTo holds the value of the "redirect_to" field.
+ RedirectTo string `json:"redirect_to,omitempty"`
+ // ResolvedEmail holds the value of the "resolved_email" field.
+ ResolvedEmail string `json:"resolved_email,omitempty"`
+ // RegistrationPasswordHash holds the value of the "registration_password_hash" field.
+ RegistrationPasswordHash string `json:"registration_password_hash,omitempty"`
+ // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field.
+ UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"`
+ // LocalFlowState holds the value of the "local_flow_state" field.
+ LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"`
+ // BrowserSessionKey holds the value of the "browser_session_key" field.
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
+ // CompletionCodeHash holds the value of the "completion_code_hash" field.
+ CompletionCodeHash string `json:"completion_code_hash,omitempty"`
+ // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field.
+ CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"`
+ // EmailVerifiedAt holds the value of the "email_verified_at" field.
+ EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
+ // PasswordVerifiedAt holds the value of the "password_verified_at" field.
+ PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"`
+ // TotpVerifiedAt holds the value of the "totp_verified_at" field.
+ TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // ConsumedAt holds the value of the "consumed_at" field.
+ ConsumedAt *time.Time `json:"consumed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the PendingAuthSessionQuery when eager-loading is set.
+ Edges PendingAuthSessionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph.
+type PendingAuthSessionEdges struct {
+ // TargetUser holds the value of the target_user edge.
+ TargetUser *User `json:"target_user,omitempty"`
+ // AdoptionDecision holds the value of the adoption_decision edge.
+ AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// TargetUserOrErr returns the TargetUser value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) {
+ if e.TargetUser != nil {
+ return e.TargetUser, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "target_user"}
+}
+
+// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) {
+ if e.AdoptionDecision != nil {
+ return e.AdoptionDecision, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: identityadoptiondecision.Label}
+ }
+ return nil, &NotLoadedError{edge: "adoption_decision"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PendingAuthSession) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState:
+ values[i] = new([]byte)
+ case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID:
+ values[i] = new(sql.NullInt64)
+ case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash:
+ values[i] = new(sql.NullString)
+ case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PendingAuthSession fields.
+func (_m *PendingAuthSession) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case pendingauthsession.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case pendingauthsession.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case pendingauthsession.FieldSessionToken:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field session_token", values[i])
+ } else if value.Valid {
+ _m.SessionToken = value.String
+ }
+ case pendingauthsession.FieldIntent:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field intent", values[i])
+ } else if value.Valid {
+ _m.Intent = value.String
+ }
+ case pendingauthsession.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case pendingauthsession.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case pendingauthsession.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case pendingauthsession.FieldTargetUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field target_user_id", values[i])
+ } else if value.Valid {
+ _m.TargetUserID = new(int64)
+ *_m.TargetUserID = value.Int64
+ }
+ case pendingauthsession.FieldRedirectTo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field redirect_to", values[i])
+ } else if value.Valid {
+ _m.RedirectTo = value.String
+ }
+ case pendingauthsession.FieldResolvedEmail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field resolved_email", values[i])
+ } else if value.Valid {
+ _m.ResolvedEmail = value.String
+ }
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i])
+ } else if value.Valid {
+ _m.RegistrationPasswordHash = value.String
+ }
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil {
+ return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err)
+ }
+ }
+ case pendingauthsession.FieldLocalFlowState:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field local_flow_state", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil {
+ return fmt.Errorf("unmarshal field local_flow_state: %w", err)
+ }
+ }
+ case pendingauthsession.FieldBrowserSessionKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field browser_session_key", values[i])
+ } else if value.Valid {
+ _m.BrowserSessionKey = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeHash = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeExpiresAt = new(time.Time)
+ *_m.CompletionCodeExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldEmailVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field email_verified_at", values[i])
+ } else if value.Valid {
+ _m.EmailVerifiedAt = new(time.Time)
+ *_m.EmailVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field password_verified_at", values[i])
+ } else if value.Valid {
+ _m.PasswordVerifiedAt = new(time.Time)
+ *_m.PasswordVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldTotpVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i])
+ } else if value.Valid {
+ _m.TotpVerifiedAt = new(time.Time)
+ *_m.TotpVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldConsumedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field consumed_at", values[i])
+ } else if value.Valid {
+ _m.ConsumedAt = new(time.Time)
+ *_m.ConsumedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession.
+// This includes values selected through modifiers, order, etc.
+func (_m *PendingAuthSession) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryTargetUser() *UserQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m)
+}
+
+// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m)
+}
+
+// Update returns a builder for updating this PendingAuthSession.
+// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne {
+ return NewPendingAuthSessionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PendingAuthSession) Unwrap() *PendingAuthSession {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PendingAuthSession is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PendingAuthSession) String() string {
+ var builder strings.Builder
+ builder.WriteString("PendingAuthSession(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("session_token=")
+ builder.WriteString(_m.SessionToken)
+ builder.WriteString(", ")
+ builder.WriteString("intent=")
+ builder.WriteString(_m.Intent)
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.TargetUserID; v != nil {
+ builder.WriteString("target_user_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("redirect_to=")
+ builder.WriteString(_m.RedirectTo)
+ builder.WriteString(", ")
+ builder.WriteString("resolved_email=")
+ builder.WriteString(_m.ResolvedEmail)
+ builder.WriteString(", ")
+ builder.WriteString("registration_password_hash=")
+ builder.WriteString(_m.RegistrationPasswordHash)
+ builder.WriteString(", ")
+ builder.WriteString("upstream_identity_claims=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims))
+ builder.WriteString(", ")
+ builder.WriteString("local_flow_state=")
+ builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState))
+ builder.WriteString(", ")
+ builder.WriteString("browser_session_key=")
+ builder.WriteString(_m.BrowserSessionKey)
+ builder.WriteString(", ")
+ builder.WriteString("completion_code_hash=")
+ builder.WriteString(_m.CompletionCodeHash)
+ builder.WriteString(", ")
+ if v := _m.CompletionCodeExpiresAt; v != nil {
+ builder.WriteString("completion_code_expires_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.EmailVerifiedAt; v != nil {
+ builder.WriteString("email_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.PasswordVerifiedAt; v != nil {
+ builder.WriteString("password_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.TotpVerifiedAt; v != nil {
+ builder.WriteString("totp_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ if v := _m.ConsumedAt; v != nil {
+ builder.WriteString("consumed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PendingAuthSessions is a parsable slice of PendingAuthSession.
+type PendingAuthSessions []*PendingAuthSession
diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go
new file mode 100644
index 00000000..8a3ac9bf
--- /dev/null
+++ b/backend/ent/pendingauthsession/pendingauthsession.go
@@ -0,0 +1,279 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the pendingauthsession type in the database.
+ Label = "pending_auth_session"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldSessionToken holds the string denoting the session_token field in the database.
+ FieldSessionToken = "session_token"
+ // FieldIntent holds the string denoting the intent field in the database.
+ FieldIntent = "intent"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldTargetUserID holds the string denoting the target_user_id field in the database.
+ FieldTargetUserID = "target_user_id"
+ // FieldRedirectTo holds the string denoting the redirect_to field in the database.
+ FieldRedirectTo = "redirect_to"
+ // FieldResolvedEmail holds the string denoting the resolved_email field in the database.
+ FieldResolvedEmail = "resolved_email"
+ // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database.
+ FieldRegistrationPasswordHash = "registration_password_hash"
+ // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database.
+ FieldUpstreamIdentityClaims = "upstream_identity_claims"
+ // FieldLocalFlowState holds the string denoting the local_flow_state field in the database.
+ FieldLocalFlowState = "local_flow_state"
+ // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database.
+ FieldBrowserSessionKey = "browser_session_key"
+ // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database.
+ FieldCompletionCodeHash = "completion_code_hash"
+ // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database.
+ FieldCompletionCodeExpiresAt = "completion_code_expires_at"
+ // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database.
+ FieldEmailVerifiedAt = "email_verified_at"
+ // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database.
+ FieldPasswordVerifiedAt = "password_verified_at"
+ // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database.
+ FieldTotpVerifiedAt = "totp_verified_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldConsumedAt holds the string denoting the consumed_at field in the database.
+ FieldConsumedAt = "consumed_at"
+ // EdgeTargetUser holds the string denoting the target_user edge name in mutations.
+ EdgeTargetUser = "target_user"
+ // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations.
+ EdgeAdoptionDecision = "adoption_decision"
+ // Table holds the table name of the pendingauthsession in the database.
+ Table = "pending_auth_sessions"
+ // TargetUserTable is the table that holds the target_user relation/edge.
+ TargetUserTable = "pending_auth_sessions"
+ // TargetUserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ TargetUserInverseTable = "users"
+ // TargetUserColumn is the table column denoting the target_user relation/edge.
+ TargetUserColumn = "target_user_id"
+ // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge.
+ AdoptionDecisionTable = "identity_adoption_decisions"
+ // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge.
+ AdoptionDecisionColumn = "pending_auth_session_id"
+)
+
+// Columns holds all SQL columns for pendingauthsession fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldSessionToken,
+ FieldIntent,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldTargetUserID,
+ FieldRedirectTo,
+ FieldResolvedEmail,
+ FieldRegistrationPasswordHash,
+ FieldUpstreamIdentityClaims,
+ FieldLocalFlowState,
+ FieldBrowserSessionKey,
+ FieldCompletionCodeHash,
+ FieldCompletionCodeExpiresAt,
+ FieldEmailVerifiedAt,
+ FieldPasswordVerifiedAt,
+ FieldTotpVerifiedAt,
+ FieldExpiresAt,
+ FieldConsumedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ SessionTokenValidator func(string) error
+ // IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ IntentValidator func(string) error
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultRedirectTo holds the default value on creation for the "redirect_to" field.
+ DefaultRedirectTo string
+ // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field.
+ DefaultResolvedEmail string
+ // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field.
+ DefaultRegistrationPasswordHash string
+ // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field.
+ DefaultUpstreamIdentityClaims func() map[string]interface{}
+ // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field.
+ DefaultLocalFlowState func() map[string]interface{}
+ // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field.
+ DefaultBrowserSessionKey string
+ // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field.
+ DefaultCompletionCodeHash string
+)
+
+// OrderOption defines the ordering options for the PendingAuthSession queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// BySessionToken orders the results by the session_token field.
+func BySessionToken(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSessionToken, opts...).ToFunc()
+}
+
+// ByIntent orders the results by the intent field.
+func ByIntent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntent, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByTargetUserID orders the results by the target_user_id field.
+func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTargetUserID, opts...).ToFunc()
+}
+
+// ByRedirectTo orders the results by the redirect_to field.
+func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRedirectTo, opts...).ToFunc()
+}
+
+// ByResolvedEmail orders the results by the resolved_email field.
+func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc()
+}
+
+// ByRegistrationPasswordHash orders the results by the registration_password_hash field.
+func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc()
+}
+
+// ByBrowserSessionKey orders the results by the browser_session_key field.
+func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc()
+}
+
+// ByCompletionCodeHash orders the results by the completion_code_hash field.
+func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc()
+}
+
+// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field.
+func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc()
+}
+
+// ByEmailVerifiedAt orders the results by the email_verified_at field.
+func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc()
+}
+
+// ByPasswordVerifiedAt orders the results by the password_verified_at field.
+func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc()
+}
+
+// ByTotpVerifiedAt orders the results by the totp_verified_at field.
+func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByConsumedAt orders the results by the consumed_at field.
+func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldConsumedAt, opts...).ToFunc()
+}
+
+// ByTargetUserField orders the results by target_user field.
+func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByAdoptionDecisionField orders the results by adoption_decision field.
+func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newTargetUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(TargetUserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+}
+func newAdoptionDecisionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+}
diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go
new file mode 100644
index 00000000..cb316f47
--- /dev/null
+++ b/backend/ent/pendingauthsession/where.go
@@ -0,0 +1,1262 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ.
+func SessionToken(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ.
+func Intent(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ.
+func TargetUserID(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ.
+func RedirectTo(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ.
+func ResolvedEmail(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ.
+func RegistrationPasswordHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ.
+func BrowserSessionKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ.
+func CompletionCodeHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ.
+func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ.
+func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ.
+func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ.
+func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ.
+func ConsumedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// SessionTokenEQ applies the EQ predicate on the "session_token" field.
+func SessionTokenEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// SessionTokenNEQ applies the NEQ predicate on the "session_token" field.
+func SessionTokenNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v))
+}
+
+// SessionTokenIn applies the In predicate on the "session_token" field.
+func SessionTokenIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenNotIn applies the NotIn predicate on the "session_token" field.
+func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenGT applies the GT predicate on the "session_token" field.
+func SessionTokenGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v))
+}
+
+// SessionTokenGTE applies the GTE predicate on the "session_token" field.
+func SessionTokenGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v))
+}
+
+// SessionTokenLT applies the LT predicate on the "session_token" field.
+func SessionTokenLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v))
+}
+
+// SessionTokenLTE applies the LTE predicate on the "session_token" field.
+func SessionTokenLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v))
+}
+
+// SessionTokenContains applies the Contains predicate on the "session_token" field.
+func SessionTokenContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v))
+}
+
+// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field.
+func SessionTokenHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v))
+}
+
+// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field.
+func SessionTokenHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v))
+}
+
+// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field.
+func SessionTokenEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v))
+}
+
+// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field.
+func SessionTokenContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v))
+}
+
+// IntentEQ applies the EQ predicate on the "intent" field.
+func IntentEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// IntentNEQ applies the NEQ predicate on the "intent" field.
+func IntentNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v))
+}
+
+// IntentIn applies the In predicate on the "intent" field.
+func IntentIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...))
+}
+
+// IntentNotIn applies the NotIn predicate on the "intent" field.
+func IntentNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...))
+}
+
+// IntentGT applies the GT predicate on the "intent" field.
+func IntentGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v))
+}
+
+// IntentGTE applies the GTE predicate on the "intent" field.
+func IntentGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v))
+}
+
+// IntentLT applies the LT predicate on the "intent" field.
+func IntentLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v))
+}
+
+// IntentLTE applies the LTE predicate on the "intent" field.
+func IntentLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v))
+}
+
+// IntentContains applies the Contains predicate on the "intent" field.
+func IntentContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v))
+}
+
+// IntentHasPrefix applies the HasPrefix predicate on the "intent" field.
+func IntentHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v))
+}
+
+// IntentHasSuffix applies the HasSuffix predicate on the "intent" field.
+func IntentHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v))
+}
+
+// IntentEqualFold applies the EqualFold predicate on the "intent" field.
+func IntentEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v))
+}
+
+// IntentContainsFold applies the ContainsFold predicate on the "intent" field.
+func IntentContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field.
+func TargetUserIDEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field.
+func TargetUserIDNEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDIn applies the In predicate on the "target_user_id" field.
+func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field.
+func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field.
+func TargetUserIDIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID))
+}
+
+// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field.
+func TargetUserIDNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID))
+}
+
+// RedirectToEQ applies the EQ predicate on the "redirect_to" field.
+func RedirectToEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field.
+func RedirectToNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v))
+}
+
+// RedirectToIn applies the In predicate on the "redirect_to" field.
+func RedirectToIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field.
+func RedirectToNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToGT applies the GT predicate on the "redirect_to" field.
+func RedirectToGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v))
+}
+
+// RedirectToGTE applies the GTE predicate on the "redirect_to" field.
+func RedirectToGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v))
+}
+
+// RedirectToLT applies the LT predicate on the "redirect_to" field.
+func RedirectToLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v))
+}
+
+// RedirectToLTE applies the LTE predicate on the "redirect_to" field.
+func RedirectToLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v))
+}
+
+// RedirectToContains applies the Contains predicate on the "redirect_to" field.
+func RedirectToContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v))
+}
+
+// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field.
+func RedirectToHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v))
+}
+
+// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field.
+func RedirectToHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v))
+}
+
+// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field.
+func RedirectToEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v))
+}
+
+// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field.
+func RedirectToContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v))
+}
+
+// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field.
+func ResolvedEmailEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field.
+func ResolvedEmailNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailIn applies the In predicate on the "resolved_email" field.
+func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field.
+func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailGT applies the GT predicate on the "resolved_email" field.
+func ResolvedEmailGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field.
+func ResolvedEmailGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLT applies the LT predicate on the "resolved_email" field.
+func ResolvedEmailLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field.
+func ResolvedEmailLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field.
+func ResolvedEmailContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field.
+func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field.
+func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field.
+func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field.
+func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field.
+func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field.
+func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field.
+func BrowserSessionKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field.
+func BrowserSessionKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field.
+func BrowserSessionKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field.
+func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field.
+func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field.
+func CompletionCodeHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field.
+func CompletionCodeHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field.
+func CompletionCodeHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt))
+}
+
+// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt))
+}
+
+// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field.
+func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field.
+func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field.
+func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt))
+}
+
+// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt))
+}
+
+// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt))
+}
+
+// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt))
+}
+
+// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt))
+}
+
+// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field.
+func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field.
+func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtIn applies the In predicate on the "consumed_at" field.
+func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field.
+func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtGT applies the GT predicate on the "consumed_at" field.
+func ConsumedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v))
+}
+
+// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field.
+func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtLT applies the LT predicate on the "consumed_at" field.
+func ConsumedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v))
+}
+
+// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field.
+func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field.
+func ConsumedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt))
+}
+
+// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field.
+func ConsumedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt))
+}
+
+// HasTargetUser applies the HasEdge predicate on the "target_user" edge.
+func HasTargetUser() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates).
+func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newTargetUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge.
+func HasAdoptionDecision() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates).
+func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newAdoptionDecisionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.NotPredicates(p))
+}
diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go
new file mode 100644
index 00000000..60276daa
--- /dev/null
+++ b/backend/ent/pendingauthsession_create.go
@@ -0,0 +1,1815 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity.
+type PendingAuthSessionCreate struct {
+ config
+ mutation *PendingAuthSessionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetSessionToken(v)
+ return _c
+}
+
+// SetIntent sets the "intent" field.
+func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetIntent(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate {
+ _c.mutation.SetTargetUserID(v)
+ return _c
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTargetUserID(*v)
+ }
+ return _c
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRedirectTo(v)
+ return _c
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRedirectTo(*v)
+ }
+ return _c
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetResolvedEmail(v)
+ return _c
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetResolvedEmail(*v)
+ }
+ return _c
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRegistrationPasswordHash(v)
+ return _c
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRegistrationPasswordHash(*v)
+ }
+ return _c
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ return _c
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetLocalFlowState(v)
+ return _c
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetBrowserSessionKey(v)
+ return _c
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetBrowserSessionKey(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeHash(v)
+ return _c
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeHash(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeExpiresAt(v)
+ return _c
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeExpiresAt(*v)
+ }
+ return _c
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetEmailVerifiedAt(v)
+ return _c
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetEmailVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetPasswordVerifiedAt(v)
+ return _c
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetPasswordVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetTotpVerifiedAt(v)
+ return _c
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTotpVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetConsumedAt(v)
+ return _c
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetConsumedAt(*v)
+ }
+ return _c
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate {
+ return _c.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate {
+ _c.mutation.SetAdoptionDecisionID(id)
+ return _c
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate {
+ if id != nil {
+ _c = _c.SetAdoptionDecisionID(*id)
+ }
+ return _c
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate {
+ return _c.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation {
+ return _c.mutation
+}
+
+// Save creates the PendingAuthSession in the database.
+func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PendingAuthSessionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := pendingauthsession.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ v := pendingauthsession.DefaultRedirectTo
+ _c.mutation.SetRedirectTo(v)
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ v := pendingauthsession.DefaultResolvedEmail
+ _c.mutation.SetResolvedEmail(v)
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ v := pendingauthsession.DefaultRegistrationPasswordHash
+ _c.mutation.SetRegistrationPasswordHash(v)
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ v := pendingauthsession.DefaultUpstreamIdentityClaims()
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ v := pendingauthsession.DefaultLocalFlowState()
+ _c.mutation.SetLocalFlowState(v)
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ v := pendingauthsession.DefaultBrowserSessionKey
+ _c.mutation.SetBrowserSessionKey(v)
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ v := pendingauthsession.DefaultCompletionCodeHash
+ _c.mutation.SetCompletionCodeHash(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PendingAuthSessionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)}
+ }
+ if _, ok := _c.mutation.SessionToken(); !ok {
+ return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)}
+ }
+ if v, ok := _c.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Intent(); !ok {
+ return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)}
+ }
+ if v, ok := _c.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)}
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)}
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)}
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)}
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)}
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)}
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)}
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)}
+ }
+ return nil
+}
+
+func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PendingAuthSession{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ _node.SessionToken = value
+ }
+ if value, ok := _c.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ _node.Intent = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ _node.RedirectTo = value
+ }
+ if value, ok := _c.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ _node.ResolvedEmail = value
+ }
+ if value, ok := _c.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ _node.RegistrationPasswordHash = value
+ }
+ if value, ok := _c.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ _node.UpstreamIdentityClaims = value
+ }
+ if value, ok := _c.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ _node.LocalFlowState = value
+ }
+ if value, ok := _c.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ _node.BrowserSessionKey = value
+ }
+ if value, ok := _c.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ _node.CompletionCodeHash = value
+ }
+ if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ _node.CompletionCodeExpiresAt = &value
+ }
+ if value, ok := _c.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ _node.EmailVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ _node.PasswordVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ _node.TotpVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := _c.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ _node.ConsumedAt = &value
+ }
+ if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TargetUserID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PendingAuthSessionUpsertOne is the builder for "upsert"-ing
+ // one PendingAuthSession node.
+ PendingAuthSessionUpsertOne struct {
+ create *PendingAuthSessionCreate
+ }
+
+ // PendingAuthSessionUpsert is the "OnConflict" setter.
+ PendingAuthSessionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpdatedAt)
+ return u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldSessionToken, v)
+ return u
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldSessionToken)
+ return u
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldIntent, v)
+ return u
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldIntent)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderSubject)
+ return u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTargetUserID, v)
+ return u
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRedirectTo, v)
+ return u
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRedirectTo)
+ return u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldResolvedEmail, v)
+ return u
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldResolvedEmail)
+ return u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRegistrationPasswordHash, v)
+ return u
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash)
+ return u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v)
+ return u
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims)
+ return u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldLocalFlowState, v)
+ return u
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldLocalFlowState)
+ return u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldBrowserSessionKey, v)
+ return u
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldBrowserSessionKey)
+ return u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeHash, v)
+ return u
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeHash)
+ return u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v)
+ return u
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldEmailVerifiedAt, v)
+ return u
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldPasswordVerifiedAt, v)
+ return u
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTotpVerifiedAt, v)
+ return u
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldExpiresAt)
+ return u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldConsumedAt, v)
+ return u
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk.
+type PendingAuthSessionCreateBulk struct {
+ config
+ err error
+ builders []*PendingAuthSessionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PendingAuthSession entities in the database.
+func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PendingAuthSession, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PendingAuthSessionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing
+// a bulk of PendingAuthSession nodes.
+type PendingAuthSessionUpsertBulk struct {
+ create *PendingAuthSessionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go
new file mode 100644
index 00000000..ee4fe605
--- /dev/null
+++ b/backend/ent/pendingauthsession_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity.
+type PendingAuthSessionDelete struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity.
+type PendingAuthSessionDeleteOne struct {
+ _d *PendingAuthSessionDelete
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go
new file mode 100644
index 00000000..78e29cd2
--- /dev/null
+++ b/backend/ent/pendingauthsession_query.go
@@ -0,0 +1,717 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities.
+type PendingAuthSessionQuery struct {
+ config
+ ctx *QueryContext
+ order []pendingauthsession.OrderOption
+ inters []Interceptor
+ predicates []predicate.PendingAuthSession
+ withTargetUser *UserQuery
+ withAdoptionDecision *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PendingAuthSessionQuery builder.
+func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryTargetUser chains the current query on the "target_user" edge.
+func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision chains the current query on the "adoption_decision" edge.
+func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first PendingAuthSession entity from the query.
+// Returns a *NotFoundError when no PendingAuthSession was found.
+func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{pendingauthsession.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PendingAuthSession ID from the query.
+// Returns a *NotFoundError when no PendingAuthSession ID was found.
+func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{pendingauthsession.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PendingAuthSession entity is found.
+// Returns a *NotFoundError when no PendingAuthSession entities are found.
+func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil, &NotSingularError{pendingauthsession.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PendingAuthSession ID in the query.
+// Returns a *NotSingularError when more than one PendingAuthSession ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{pendingauthsession.Label}
+ default:
+ err = &NotSingularError{pendingauthsession.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PendingAuthSessions.
+func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]()
+ return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PendingAuthSession IDs.
+func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PendingAuthSessionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]pendingauthsession.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PendingAuthSession{}, _q.predicates...),
+ withTargetUser: _q.withTargetUser.Clone(),
+ withAdoptionDecision: _q.withAdoptionDecision.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithTargetUser tells the query-builder to eager-load the nodes that are connected to
+// the "target_user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withTargetUser = query
+ return _q
+}
+
+// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecision = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// GroupBy(pendingauthsession.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PendingAuthSessionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = pendingauthsession.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// Select(pendingauthsession.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q}
+ sbuild.label = pendingauthsession.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations.
+func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) {
+ var (
+ nodes = []*PendingAuthSession{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withTargetUser != nil,
+ _q.withAdoptionDecision != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*PendingAuthSession).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &PendingAuthSession{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withTargetUser; query != nil {
+ if err := _q.loadTargetUser(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecision; query != nil {
+ if err := _q.loadAdoptionDecision(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*PendingAuthSession)
+ for i := range nodes {
+ if nodes[i].TargetUserID == nil {
+ continue
+ }
+ fk := *nodes[i].TargetUserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*PendingAuthSession)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.PendingAuthSessionID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for i := range fields {
+ if fields[i] != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withTargetUser != nil {
+ _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(pendingauthsession.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = pendingauthsession.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities.
+type PendingAuthSessionGroupBy struct {
+ selector
+ build *PendingAuthSessionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities.
+type PendingAuthSessionSelect struct {
+ *PendingAuthSessionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v)
+}
+
+func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go
new file mode 100644
index 00000000..00066f69
--- /dev/null
+++ b/backend/ent/pendingauthsession_update.go
@@ -0,0 +1,1178 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities.
+type PendingAuthSessionUpdate struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdate) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity.
+type PendingAuthSessionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PendingAuthSession entity.
+func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdateOne) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for _, f := range fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &PendingAuthSession{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index ef551940..dc86471e 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -21,6 +21,24 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
+// AuthIdentity is the predicate function for authidentity builders.
+type AuthIdentity func(*sql.Selector)
+
+// AuthIdentityChannel is the predicate function for authidentitychannel builders.
+type AuthIdentityChannel func(*sql.Selector)
+
+// ChannelMonitor is the predicate function for channelmonitor builders.
+type ChannelMonitor func(*sql.Selector)
+
+// ChannelMonitorDailyRollup is the predicate function for channelmonitordailyrollup builders.
+type ChannelMonitorDailyRollup func(*sql.Selector)
+
+// ChannelMonitorHistory is the predicate function for channelmonitorhistory builders.
+type ChannelMonitorHistory func(*sql.Selector)
+
+// ChannelMonitorRequestTemplate is the predicate function for channelmonitorrequesttemplate builders.
+type ChannelMonitorRequestTemplate func(*sql.Selector)
+
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
@@ -30,6 +48,9 @@ type Group func(*sql.Selector)
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
type IdempotencyRecord func(*sql.Selector)
+// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders.
+type IdentityAdoptionDecision func(*sql.Selector)
+
// PaymentAuditLog is the predicate function for paymentauditlog builders.
type PaymentAuditLog func(*sql.Selector)
@@ -39,6 +60,9 @@ type PaymentOrder func(*sql.Selector)
// PaymentProviderInstance is the predicate function for paymentproviderinstance builders.
type PaymentProviderInstance func(*sql.Selector)
+// PendingAuthSession is the predicate function for pendingauthsession builders.
+type PendingAuthSession func(*sql.Selector)
+
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index fbdd08c7..6b344a55 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -10,12 +10,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -309,6 +317,366 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
+ authidentityMixin := schema.AuthIdentity{}.Mixin()
+ authidentityMixinFields0 := authidentityMixin[0].Fields()
+ _ = authidentityMixinFields0
+ authidentityFields := schema.AuthIdentity{}.Fields()
+ _ = authidentityFields
+ // authidentityDescCreatedAt is the schema descriptor for created_at field.
+ authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor()
+ // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time)
+ // authidentityDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor()
+ // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time)
+ // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentityDescProviderType is the schema descriptor for provider_type field.
+ authidentityDescProviderType := authidentityFields[1].Descriptor()
+ // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentity.ProviderTypeValidator = func() func(string) error {
+ validators := authidentityDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentityDescProviderKey is the schema descriptor for provider_key field.
+ authidentityDescProviderKey := authidentityFields[2].Descriptor()
+ // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error)
+ // authidentityDescProviderSubject is the schema descriptor for provider_subject field.
+ authidentityDescProviderSubject := authidentityFields[3].Descriptor()
+ // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error)
+ // authidentityDescMetadata is the schema descriptor for metadata field.
+ authidentityDescMetadata := authidentityFields[6].Descriptor()
+ // authidentity.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{})
+ authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin()
+ authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields()
+ _ = authidentitychannelMixinFields0
+ authidentitychannelFields := schema.AuthIdentityChannel{}.Fields()
+ _ = authidentitychannelFields
+ // authidentitychannelDescCreatedAt is the schema descriptor for created_at field.
+ authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor()
+ // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time)
+ // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor()
+ // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time)
+ // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentitychannelDescProviderType is the schema descriptor for provider_type field.
+ authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor()
+ // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentitychannel.ProviderTypeValidator = func() func(string) error {
+ validators := authidentitychannelDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescProviderKey is the schema descriptor for provider_key field.
+ authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor()
+ // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error)
+ // authidentitychannelDescChannel is the schema descriptor for channel field.
+ authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor()
+ // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ authidentitychannel.ChannelValidator = func() func(string) error {
+ validators := authidentitychannelDescChannel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(channel string) error {
+ for _, fn := range fns {
+ if err := fn(channel); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field.
+ authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor()
+ // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error)
+ // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field.
+ authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor()
+ // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error)
+ // authidentitychannelDescMetadata is the schema descriptor for metadata field.
+ authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor()
+ // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{})
+ channelmonitorMixin := schema.ChannelMonitor{}.Mixin()
+ channelmonitorMixinFields0 := channelmonitorMixin[0].Fields()
+ _ = channelmonitorMixinFields0
+ channelmonitorFields := schema.ChannelMonitor{}.Fields()
+ _ = channelmonitorFields
+ // channelmonitorDescCreatedAt is the schema descriptor for created_at field.
+ channelmonitorDescCreatedAt := channelmonitorMixinFields0[0].Descriptor()
+ // channelmonitor.DefaultCreatedAt holds the default value on creation for the created_at field.
+ channelmonitor.DefaultCreatedAt = channelmonitorDescCreatedAt.Default.(func() time.Time)
+ // channelmonitorDescUpdatedAt is the schema descriptor for updated_at field.
+ channelmonitorDescUpdatedAt := channelmonitorMixinFields0[1].Descriptor()
+ // channelmonitor.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ channelmonitor.DefaultUpdatedAt = channelmonitorDescUpdatedAt.Default.(func() time.Time)
+ // channelmonitor.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ channelmonitor.UpdateDefaultUpdatedAt = channelmonitorDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // channelmonitorDescName is the schema descriptor for name field.
+ channelmonitorDescName := channelmonitorFields[0].Descriptor()
+ // channelmonitor.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ channelmonitor.NameValidator = func() func(string) error {
+ validators := channelmonitorDescName.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(name string) error {
+ for _, fn := range fns {
+ if err := fn(name); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorDescEndpoint is the schema descriptor for endpoint field.
+ channelmonitorDescEndpoint := channelmonitorFields[2].Descriptor()
+ // channelmonitor.EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
+ channelmonitor.EndpointValidator = func() func(string) error {
+ validators := channelmonitorDescEndpoint.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(endpoint string) error {
+ for _, fn := range fns {
+ if err := fn(endpoint); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorDescAPIKeyEncrypted is the schema descriptor for api_key_encrypted field.
+ channelmonitorDescAPIKeyEncrypted := channelmonitorFields[3].Descriptor()
+ // channelmonitor.APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
+ channelmonitor.APIKeyEncryptedValidator = channelmonitorDescAPIKeyEncrypted.Validators[0].(func(string) error)
+ // channelmonitorDescPrimaryModel is the schema descriptor for primary_model field.
+ channelmonitorDescPrimaryModel := channelmonitorFields[4].Descriptor()
+ // channelmonitor.PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
+ channelmonitor.PrimaryModelValidator = func() func(string) error {
+ validators := channelmonitorDescPrimaryModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(primary_model string) error {
+ for _, fn := range fns {
+ if err := fn(primary_model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorDescExtraModels is the schema descriptor for extra_models field.
+ channelmonitorDescExtraModels := channelmonitorFields[5].Descriptor()
+ // channelmonitor.DefaultExtraModels holds the default value on creation for the extra_models field.
+ channelmonitor.DefaultExtraModels = channelmonitorDescExtraModels.Default.([]string)
+ // channelmonitorDescGroupName is the schema descriptor for group_name field.
+ channelmonitorDescGroupName := channelmonitorFields[6].Descriptor()
+ // channelmonitor.DefaultGroupName holds the default value on creation for the group_name field.
+ channelmonitor.DefaultGroupName = channelmonitorDescGroupName.Default.(string)
+ // channelmonitor.GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
+ channelmonitor.GroupNameValidator = channelmonitorDescGroupName.Validators[0].(func(string) error)
+ // channelmonitorDescEnabled is the schema descriptor for enabled field.
+ channelmonitorDescEnabled := channelmonitorFields[7].Descriptor()
+ // channelmonitor.DefaultEnabled holds the default value on creation for the enabled field.
+ channelmonitor.DefaultEnabled = channelmonitorDescEnabled.Default.(bool)
+ // channelmonitorDescIntervalSeconds is the schema descriptor for interval_seconds field.
+ channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
+ // channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
+ channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
+ // channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field.
+ channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor()
+ // channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
+ channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string)
+ // channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field.
+ channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor()
+ // channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
+ channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string)
+ // channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ channelmonitor.BodyOverrideModeValidator = channelmonitorDescBodyOverrideMode.Validators[0].(func(string) error)
+ channelmonitordailyrollupFields := schema.ChannelMonitorDailyRollup{}.Fields()
+ _ = channelmonitordailyrollupFields
+ // channelmonitordailyrollupDescModel is the schema descriptor for model field.
+ channelmonitordailyrollupDescModel := channelmonitordailyrollupFields[1].Descriptor()
+ // channelmonitordailyrollup.ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ channelmonitordailyrollup.ModelValidator = func() func(string) error {
+ validators := channelmonitordailyrollupDescModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(model string) error {
+ for _, fn := range fns {
+ if err := fn(model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitordailyrollupDescTotalChecks is the schema descriptor for total_checks field.
+ channelmonitordailyrollupDescTotalChecks := channelmonitordailyrollupFields[3].Descriptor()
+ // channelmonitordailyrollup.DefaultTotalChecks holds the default value on creation for the total_checks field.
+ channelmonitordailyrollup.DefaultTotalChecks = channelmonitordailyrollupDescTotalChecks.Default.(int)
+ // channelmonitordailyrollupDescOkCount is the schema descriptor for ok_count field.
+ channelmonitordailyrollupDescOkCount := channelmonitordailyrollupFields[4].Descriptor()
+ // channelmonitordailyrollup.DefaultOkCount holds the default value on creation for the ok_count field.
+ channelmonitordailyrollup.DefaultOkCount = channelmonitordailyrollupDescOkCount.Default.(int)
+ // channelmonitordailyrollupDescOperationalCount is the schema descriptor for operational_count field.
+ channelmonitordailyrollupDescOperationalCount := channelmonitordailyrollupFields[5].Descriptor()
+ // channelmonitordailyrollup.DefaultOperationalCount holds the default value on creation for the operational_count field.
+ channelmonitordailyrollup.DefaultOperationalCount = channelmonitordailyrollupDescOperationalCount.Default.(int)
+ // channelmonitordailyrollupDescDegradedCount is the schema descriptor for degraded_count field.
+ channelmonitordailyrollupDescDegradedCount := channelmonitordailyrollupFields[6].Descriptor()
+ // channelmonitordailyrollup.DefaultDegradedCount holds the default value on creation for the degraded_count field.
+ channelmonitordailyrollup.DefaultDegradedCount = channelmonitordailyrollupDescDegradedCount.Default.(int)
+ // channelmonitordailyrollupDescFailedCount is the schema descriptor for failed_count field.
+ channelmonitordailyrollupDescFailedCount := channelmonitordailyrollupFields[7].Descriptor()
+ // channelmonitordailyrollup.DefaultFailedCount holds the default value on creation for the failed_count field.
+ channelmonitordailyrollup.DefaultFailedCount = channelmonitordailyrollupDescFailedCount.Default.(int)
+ // channelmonitordailyrollupDescErrorCount is the schema descriptor for error_count field.
+ channelmonitordailyrollupDescErrorCount := channelmonitordailyrollupFields[8].Descriptor()
+ // channelmonitordailyrollup.DefaultErrorCount holds the default value on creation for the error_count field.
+ channelmonitordailyrollup.DefaultErrorCount = channelmonitordailyrollupDescErrorCount.Default.(int)
+ // channelmonitordailyrollupDescSumLatencyMs is the schema descriptor for sum_latency_ms field.
+ channelmonitordailyrollupDescSumLatencyMs := channelmonitordailyrollupFields[9].Descriptor()
+ // channelmonitordailyrollup.DefaultSumLatencyMs holds the default value on creation for the sum_latency_ms field.
+ channelmonitordailyrollup.DefaultSumLatencyMs = channelmonitordailyrollupDescSumLatencyMs.Default.(int64)
+ // channelmonitordailyrollupDescCountLatency is the schema descriptor for count_latency field.
+ channelmonitordailyrollupDescCountLatency := channelmonitordailyrollupFields[10].Descriptor()
+ // channelmonitordailyrollup.DefaultCountLatency holds the default value on creation for the count_latency field.
+ channelmonitordailyrollup.DefaultCountLatency = channelmonitordailyrollupDescCountLatency.Default.(int)
+ // channelmonitordailyrollupDescSumPingLatencyMs is the schema descriptor for sum_ping_latency_ms field.
+ channelmonitordailyrollupDescSumPingLatencyMs := channelmonitordailyrollupFields[11].Descriptor()
+ // channelmonitordailyrollup.DefaultSumPingLatencyMs holds the default value on creation for the sum_ping_latency_ms field.
+ channelmonitordailyrollup.DefaultSumPingLatencyMs = channelmonitordailyrollupDescSumPingLatencyMs.Default.(int64)
+ // channelmonitordailyrollupDescCountPingLatency is the schema descriptor for count_ping_latency field.
+ channelmonitordailyrollupDescCountPingLatency := channelmonitordailyrollupFields[12].Descriptor()
+ // channelmonitordailyrollup.DefaultCountPingLatency holds the default value on creation for the count_ping_latency field.
+ channelmonitordailyrollup.DefaultCountPingLatency = channelmonitordailyrollupDescCountPingLatency.Default.(int)
+ // channelmonitordailyrollupDescComputedAt is the schema descriptor for computed_at field.
+ channelmonitordailyrollupDescComputedAt := channelmonitordailyrollupFields[13].Descriptor()
+ // channelmonitordailyrollup.DefaultComputedAt holds the default value on creation for the computed_at field.
+ channelmonitordailyrollup.DefaultComputedAt = channelmonitordailyrollupDescComputedAt.Default.(func() time.Time)
+ // channelmonitordailyrollup.UpdateDefaultComputedAt holds the default value on update for the computed_at field.
+ channelmonitordailyrollup.UpdateDefaultComputedAt = channelmonitordailyrollupDescComputedAt.UpdateDefault.(func() time.Time)
+ channelmonitorhistoryFields := schema.ChannelMonitorHistory{}.Fields()
+ _ = channelmonitorhistoryFields
+ // channelmonitorhistoryDescModel is the schema descriptor for model field.
+ channelmonitorhistoryDescModel := channelmonitorhistoryFields[1].Descriptor()
+ // channelmonitorhistory.ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ channelmonitorhistory.ModelValidator = func() func(string) error {
+ validators := channelmonitorhistoryDescModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(model string) error {
+ for _, fn := range fns {
+ if err := fn(model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorhistoryDescMessage is the schema descriptor for message field.
+ channelmonitorhistoryDescMessage := channelmonitorhistoryFields[5].Descriptor()
+ // channelmonitorhistory.DefaultMessage holds the default value on creation for the message field.
+ channelmonitorhistory.DefaultMessage = channelmonitorhistoryDescMessage.Default.(string)
+ // channelmonitorhistory.MessageValidator is a validator for the "message" field. It is called by the builders before save.
+ channelmonitorhistory.MessageValidator = channelmonitorhistoryDescMessage.Validators[0].(func(string) error)
+ // channelmonitorhistoryDescCheckedAt is the schema descriptor for checked_at field.
+ channelmonitorhistoryDescCheckedAt := channelmonitorhistoryFields[6].Descriptor()
+ // channelmonitorhistory.DefaultCheckedAt holds the default value on creation for the checked_at field.
+ channelmonitorhistory.DefaultCheckedAt = channelmonitorhistoryDescCheckedAt.Default.(func() time.Time)
+ channelmonitorrequesttemplateMixin := schema.ChannelMonitorRequestTemplate{}.Mixin()
+ channelmonitorrequesttemplateMixinFields0 := channelmonitorrequesttemplateMixin[0].Fields()
+ _ = channelmonitorrequesttemplateMixinFields0
+ channelmonitorrequesttemplateFields := schema.ChannelMonitorRequestTemplate{}.Fields()
+ _ = channelmonitorrequesttemplateFields
+ // channelmonitorrequesttemplateDescCreatedAt is the schema descriptor for created_at field.
+ channelmonitorrequesttemplateDescCreatedAt := channelmonitorrequesttemplateMixinFields0[0].Descriptor()
+ // channelmonitorrequesttemplate.DefaultCreatedAt holds the default value on creation for the created_at field.
+ channelmonitorrequesttemplate.DefaultCreatedAt = channelmonitorrequesttemplateDescCreatedAt.Default.(func() time.Time)
+ // channelmonitorrequesttemplateDescUpdatedAt is the schema descriptor for updated_at field.
+ channelmonitorrequesttemplateDescUpdatedAt := channelmonitorrequesttemplateMixinFields0[1].Descriptor()
+ // channelmonitorrequesttemplate.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ channelmonitorrequesttemplate.DefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.Default.(func() time.Time)
+ // channelmonitorrequesttemplate.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ channelmonitorrequesttemplate.UpdateDefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // channelmonitorrequesttemplateDescName is the schema descriptor for name field.
+ channelmonitorrequesttemplateDescName := channelmonitorrequesttemplateFields[0].Descriptor()
+ // channelmonitorrequesttemplate.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.NameValidator = func() func(string) error {
+ validators := channelmonitorrequesttemplateDescName.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(name string) error {
+ for _, fn := range fns {
+ if err := fn(name); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorrequesttemplateDescDescription is the schema descriptor for description field.
+ channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor()
+ // channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field.
+ channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string)
+ // channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error)
+ // channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field.
+ channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor()
+ // channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
+ channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string)
+ // channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field.
+ channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor()
+ // channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
+ channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string)
+ // channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.BodyOverrideModeValidator = channelmonitorrequesttemplateDescBodyOverrideMode.Validators[0].(func(string) error)
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
@@ -477,6 +845,10 @@ func init() {
groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
+ // groupDescRpmLimit is the schema descriptor for rpm_limit field.
+ groupDescRpmLimit := groupFields[27].Descriptor()
+ // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
+ group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0
@@ -512,6 +884,33 @@ func init() {
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
+ identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin()
+ identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields()
+ _ = identityadoptiondecisionMixinFields0
+ identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields()
+ _ = identityadoptiondecisionFields
+ // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field.
+ identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor()
+ // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field.
+ identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time)
+ // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field.
+ identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor()
+ // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time)
+ // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field.
+ identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor()
+ // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field.
+ identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool)
+ // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field.
+ identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor()
+ // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field.
+ identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool)
+ // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field.
+ identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor()
+ // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field.
+ identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time)
paymentauditlogFields := schema.PaymentAuditLog{}.Fields()
_ = paymentauditlogFields
// paymentauditlogDescOrderID is the schema descriptor for order_id field.
@@ -578,38 +977,42 @@ func init() {
paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor()
// paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error)
+ // paymentorderDescProviderKey is the schema descriptor for provider_key field.
+ paymentorderDescProviderKey := paymentorderFields[19].Descriptor()
+ // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error)
// paymentorderDescStatus is the schema descriptor for status field.
- paymentorderDescStatus := paymentorderFields[19].Descriptor()
+ paymentorderDescStatus := paymentorderFields[21].Descriptor()
// paymentorder.DefaultStatus holds the default value on creation for the status field.
paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string)
// paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save.
paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error)
// paymentorderDescRefundAmount is the schema descriptor for refund_amount field.
- paymentorderDescRefundAmount := paymentorderFields[20].Descriptor()
+ paymentorderDescRefundAmount := paymentorderFields[22].Descriptor()
// paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field.
paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64)
// paymentorderDescForceRefund is the schema descriptor for force_refund field.
- paymentorderDescForceRefund := paymentorderFields[23].Descriptor()
+ paymentorderDescForceRefund := paymentorderFields[25].Descriptor()
// paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field.
paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool)
// paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field.
- paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor()
+ paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor()
// paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error)
// paymentorderDescClientIP is the schema descriptor for client_ip field.
- paymentorderDescClientIP := paymentorderFields[32].Descriptor()
+ paymentorderDescClientIP := paymentorderFields[34].Descriptor()
// paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error)
// paymentorderDescSrcHost is the schema descriptor for src_host field.
- paymentorderDescSrcHost := paymentorderFields[33].Descriptor()
+ paymentorderDescSrcHost := paymentorderFields[35].Descriptor()
// paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error)
// paymentorderDescCreatedAt is the schema descriptor for created_at field.
- paymentorderDescCreatedAt := paymentorderFields[35].Descriptor()
+ paymentorderDescCreatedAt := paymentorderFields[37].Descriptor()
// paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time)
// paymentorderDescUpdatedAt is the schema descriptor for updated_at field.
- paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor()
+ paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor()
// paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time)
// paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
@@ -682,6 +1085,113 @@ func init() {
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time)
+ pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin()
+ pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields()
+ _ = pendingauthsessionMixinFields0
+ pendingauthsessionFields := schema.PendingAuthSession{}.Fields()
+ _ = pendingauthsessionFields
+ // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field.
+ pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor()
+ // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field.
+ pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time)
+ // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field.
+ pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor()
+ // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time)
+ // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // pendingauthsessionDescSessionToken is the schema descriptor for session_token field.
+ pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor()
+ // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ pendingauthsession.SessionTokenValidator = func() func(string) error {
+ validators := pendingauthsessionDescSessionToken.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(session_token string) error {
+ for _, fn := range fns {
+ if err := fn(session_token); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescIntent is the schema descriptor for intent field.
+ pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor()
+ // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ pendingauthsession.IntentValidator = func() func(string) error {
+ validators := pendingauthsessionDescIntent.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(intent string) error {
+ for _, fn := range fns {
+ if err := fn(intent); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderType is the schema descriptor for provider_type field.
+ pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor()
+ // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ pendingauthsession.ProviderTypeValidator = func() func(string) error {
+ validators := pendingauthsessionDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field.
+ pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor()
+ // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error)
+ // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field.
+ pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor()
+ // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error)
+ // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field.
+ pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor()
+ // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field.
+ pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string)
+ // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field.
+ pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor()
+ // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field.
+ pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string)
+ // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field.
+ pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor()
+ // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field.
+ pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string)
+ // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field.
+ pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor()
+ // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field.
+ pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{})
+ // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field.
+ pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor()
+ // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field.
+ pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{})
+ // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field.
+ pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor()
+ // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field.
+ pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string)
+ // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field.
+ pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor()
+ // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field.
+ pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -1297,22 +1807,32 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescSignupSource is the schema descriptor for signup_source field.
+ userDescSignupSource := userFields[11].Descriptor()
+ // user.DefaultSignupSource holds the default value on creation for the signup_source field.
+ user.DefaultSignupSource = userDescSignupSource.Default.(string)
+ // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
- userDescBalanceNotifyEnabled := userFields[11].Descriptor()
+ userDescBalanceNotifyEnabled := userFields[14].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
- userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
+ userDescBalanceNotifyThresholdType := userFields[15].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
- userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
+ userDescBalanceNotifyExtraEmails := userFields[17].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
- userDescTotalRecharged := userFields[15].Descriptor()
+ userDescTotalRecharged := userFields[18].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
+ // userDescRpmLimit is the schema descriptor for rpm_limit field.
+ userDescRpmLimit := userFields[19].Descriptor()
+ // user.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
+ user.DefaultRpmLimit = userDescRpmLimit.Default.(int)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go
new file mode 100644
index 00000000..0b1b56ab
--- /dev/null
+++ b/backend/ent/schema/auth_identity.go
@@ -0,0 +1,94 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var authProviderTypes = map[string]struct{}{
+ "email": {},
+ "linuxdo": {},
+ "oidc": {},
+ "wechat": {},
+}
+
+func validateAuthProviderType(value string) error {
+ if _, ok := authProviderTypes[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid auth provider type %q", value)
+}
+
+// AuthIdentity stores the canonical login identity for an account.
+type AuthIdentity struct {
+ ent.Schema
+}
+
+func (AuthIdentity) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identities"},
+ }
+}
+
+func (AuthIdentity) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentity) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("user_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("issuer").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentity) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("user", User.Type).
+ Ref("auth_identities").
+ Field("user_id").
+ Required().
+ Unique(),
+ edge.To("channels", AuthIdentityChannel.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
+ }
+}
+
+func (AuthIdentity) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "provider_subject").Unique(),
+ index.Fields("user_id"),
+ index.Fields("user_id", "provider_type"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go
new file mode 100644
index 00000000..69f2ad02
--- /dev/null
+++ b/backend/ent/schema/auth_identity_channel.go
@@ -0,0 +1,72 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity.
+type AuthIdentityChannel struct {
+ ent.Schema
+}
+
+func (AuthIdentityChannel) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identity_channels"},
+ }
+}
+
+func (AuthIdentityChannel) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentityChannel) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("identity_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel").
+ MaxLen(20).
+ NotEmpty(),
+ field.String("channel_app_id").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentityChannel) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("identity", AuthIdentity.Type).
+ Ref("channels").
+ Field("identity_id").
+ Required().
+ Unique(),
+ }
+}
+
+func (AuthIdentityChannel) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go
new file mode 100644
index 00000000..fbb93236
--- /dev/null
+++ b/backend/ent/schema/auth_identity_schema_test.go
@@ -0,0 +1,168 @@
+package schema
+
+import (
+ "testing"
+
+ "entgo.io/ent"
+ "entgo.io/ent/entc/load"
+ "entgo.io/ent/schema/field"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationSchemas(t *testing.T) {
+ spec, err := (&load.Config{Path: "."}).Load()
+ require.NoError(t, err)
+
+ schemas := map[string]*load.Schema{}
+ for _, schema := range spec.Schemas {
+ schemas[schema.Name] = schema
+ }
+
+ authIdentity := requireSchema(t, schemas, "AuthIdentity")
+ requireSchemaFields(t, authIdentity,
+ "user_id",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "verified_at",
+ "issuer",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject")
+
+ authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel")
+ requireSchemaFields(t, authIdentityChannel,
+ "identity_id",
+ "provider_type",
+ "provider_key",
+ "channel",
+ "channel_app_id",
+ "channel_subject",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject")
+
+ pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession")
+ requireSchemaFields(t, pendingAuthSession,
+ "intent",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "target_user_id",
+ "redirect_to",
+ "resolved_email",
+ "registration_password_hash",
+ "upstream_identity_claims",
+ "local_flow_state",
+ "browser_session_key",
+ "completion_code_hash",
+ "completion_code_expires_at",
+ "email_verified_at",
+ "password_verified_at",
+ "totp_verified_at",
+ "expires_at",
+ "consumed_at",
+ )
+
+ adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision")
+ requireSchemaFields(t, adoptionDecision,
+ "pending_auth_session_id",
+ "identity_id",
+ "adopt_display_name",
+ "adopt_avatar",
+ "decided_at",
+ )
+ requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id")
+
+ userSchema := requireSchema(t, schemas, "User")
+ requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
+ signupSource := requireSchemaField(t, userSchema, "signup_source")
+ require.Equal(t, field.TypeString, signupSource.Info.Type)
+ require.True(t, signupSource.Default)
+ require.Equal(t, "email", signupSource.DefaultValue)
+ require.Equal(t, 1, signupSource.Validators)
+
+ validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
+ for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
+ require.NoError(t, validator(value))
+ }
+ require.Error(t, validator("github"))
+}
+
+func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
+ t.Helper()
+
+ schema, ok := schemas[name]
+ require.True(t, ok, "schema %s should exist", name)
+ return schema
+}
+
+func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
+ t.Helper()
+
+ fields := map[string]struct{}{}
+ for _, field := range schema.Fields {
+ fields[field.Name] = struct{}{}
+ }
+
+ for _, name := range names {
+ _, ok := fields[name]
+ require.True(t, ok, "schema %s should include field %s", schema.Name, name)
+ }
+}
+
+func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
+ t.Helper()
+
+ for _, schemaField := range schema.Fields {
+ if schemaField.Name == name {
+ return schemaField
+ }
+ }
+
+ require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
+ return nil
+}
+
+func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
+ t.Helper()
+
+ for _, entField := range fields {
+ descriptor := entField.Descriptor()
+ if descriptor.Name != name {
+ continue
+ }
+ require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
+ validator, ok := descriptor.Validators[0].(func(string) error)
+ require.True(t, ok, "field %s validator should be func(string) error", name)
+ return validator
+ }
+
+ require.Failf(t, "missing field validator", "schema should include field %s", name)
+ return nil
+}
+
+func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
+ t.Helper()
+
+ for _, index := range schema.Indexes {
+ if !index.Unique {
+ continue
+ }
+ if len(index.Fields) != len(fields) {
+ continue
+ }
+ match := true
+ for i := range fields {
+ if index.Fields[i] != fields[i] {
+ match = false
+ break
+ }
+ }
+ if match {
+ return
+ }
+ }
+
+ require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields)
+}
diff --git a/backend/ent/schema/channel_monitor.go b/backend/ent/schema/channel_monitor.go
new file mode 100644
index 00000000..355ade4b
--- /dev/null
+++ b/backend/ent/schema/channel_monitor.go
@@ -0,0 +1,110 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitor holds the schema definition for the ChannelMonitor entity.
+// 渠道监控配置:定期对指定 provider/endpoint/api_key 下的模型做心跳测试。
+type ChannelMonitor struct {
+ ent.Schema
+}
+
+func (ChannelMonitor) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitors"},
+ }
+}
+
+func (ChannelMonitor) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (ChannelMonitor) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("name").
+ NotEmpty().
+ MaxLen(100),
+ field.Enum("provider").
+ Values("openai", "anthropic", "gemini"),
+ field.String("endpoint").
+ NotEmpty().
+ MaxLen(500).
+ Comment("Provider base origin, e.g. https://api.openai.com"),
+ field.String("api_key_encrypted").
+ NotEmpty().
+ Sensitive().
+ Comment("AES-256-GCM encrypted API key"),
+ field.String("primary_model").
+ NotEmpty().
+ MaxLen(200),
+ field.JSON("extra_models", []string{}).
+ Default([]string{}).
+ Comment("Additional model names to test alongside primary_model"),
+ field.String("group_name").
+ Optional().
+ Default("").
+ MaxLen(100),
+ field.Bool("enabled").
+ Default(true),
+ field.Int("interval_seconds").
+ Range(15, 3600),
+ field.Time("last_checked_at").
+ Optional().
+ Nillable(),
+ field.Int64("created_by"),
+
+ // ---- 自定义请求快照字段(来自模板 / 手动编辑) ----
+
+ // template_id: 关联的请求模板 ID(仅用于 UI 分组 + 一键应用)。
+ // 实际运行时 checker 只读下面 3 个快照字段,**不再回查模板表**。
+ // 模板被删除时此字段会被 SET NULL(见 Edges 的 OnDelete 注解)。
+ field.Int64("template_id").
+ Optional().
+ Nillable(),
+ // extra_headers: 自定义 HTTP 头快照(来自模板 or 用户手填)。
+ // 运行时 merge 进 adapter 默认 headers。
+ field.JSON("extra_headers", map[string]string{}).
+ Default(map[string]string{}),
+ // body_override_mode: 同 ChannelMonitorRequestTemplate.body_override_mode
+ field.String("body_override_mode").
+ Default("off").
+ MaxLen(10),
+ // body_override: 同 ChannelMonitorRequestTemplate.body_override
+ field.JSON("body_override", map[string]any{}).
+ Optional(),
+ }
+}
+
+func (ChannelMonitor) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.To("history", ChannelMonitorHistory.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("daily_rollups", ChannelMonitorDailyRollup.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ // 关联请求模板:模板被删除时 template_id 自动置空,
+ // 监控本身保留(继续用快照字段跑)。
+ edge.To("request_template", ChannelMonitorRequestTemplate.Type).
+ Field("template_id").
+ Unique().
+ Annotations(entsql.OnDelete(entsql.SetNull)),
+ }
+}
+
+func (ChannelMonitor) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("enabled", "last_checked_at"),
+ index.Fields("provider"),
+ index.Fields("group_name"),
+ index.Fields("template_id"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_daily_rollup.go b/backend/ent/schema/channel_monitor_daily_rollup.go
new file mode 100644
index 00000000..23f032e3
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_daily_rollup.go
@@ -0,0 +1,66 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorDailyRollup 按 (monitor_id, model, bucket_date) 维度聚合的渠道监控日统计。
+// 每天的明细被收敛为一行(保留 status 分布 + 延迟和),用于 7d/15d/30d 窗口的可用率
+// 加权计算(avg_latency = sum_latency_ms / count_latency;availability = ok_count / total_checks)。
+// 超过保留期由每日维护任务分批物理删(不用软删除,理由同 channel_monitor_history)。
+type ChannelMonitorDailyRollup struct {
+ ent.Schema
+}
+
+func (ChannelMonitorDailyRollup) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_daily_rollups"},
+ }
+}
+
+func (ChannelMonitorDailyRollup) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("monitor_id"),
+ field.String("model").
+ NotEmpty().
+ MaxLen(200),
+ field.Time("bucket_date").
+ SchemaType(map[string]string{dialect.Postgres: "date"}),
+ field.Int("total_checks").Default(0),
+ field.Int("ok_count").Default(0),
+ field.Int("operational_count").Default(0),
+ field.Int("degraded_count").Default(0),
+ field.Int("failed_count").Default(0),
+ field.Int("error_count").Default(0),
+ field.Int64("sum_latency_ms").Default(0),
+ field.Int("count_latency").Default(0),
+ field.Int64("sum_ping_latency_ms").Default(0),
+ field.Int("count_ping_latency").Default(0),
+ field.Time("computed_at").Default(time.Now).UpdateDefault(time.Now),
+ }
+}
+
+func (ChannelMonitorDailyRollup) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitor", ChannelMonitor.Type).
+ Ref("daily_rollups").
+ Field("monitor_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (ChannelMonitorDailyRollup) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("monitor_id", "model", "bucket_date").Unique(),
+ index.Fields("bucket_date"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_history.go b/backend/ent/schema/channel_monitor_history.go
new file mode 100644
index 00000000..4366e79a
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_history.go
@@ -0,0 +1,66 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorHistory holds the schema definition for the ChannelMonitorHistory entity.
+// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天由每日维护任务
+// 先聚合到 channel_monitor_daily_rollups,再分批物理删(不用软删除:日志类表无恢复
+// 需求,软删会让行和索引只增不减,徒增磁盘和查询开销)。
+type ChannelMonitorHistory struct {
+ ent.Schema
+}
+
+func (ChannelMonitorHistory) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_histories"},
+ }
+}
+
+func (ChannelMonitorHistory) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("monitor_id"),
+ field.String("model").
+ NotEmpty().
+ MaxLen(200),
+ field.Enum("status").
+ Values("operational", "degraded", "failed", "error"),
+ field.Int("latency_ms").
+ Optional().
+ Nillable(),
+ field.Int("ping_latency_ms").
+ Optional().
+ Nillable(),
+ field.String("message").
+ Optional().
+ Default("").
+ MaxLen(500),
+ field.Time("checked_at").
+ Default(time.Now),
+ }
+}
+
+func (ChannelMonitorHistory) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitor", ChannelMonitor.Type).
+ Ref("history").
+ Field("monitor_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (ChannelMonitorHistory) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("monitor_id", "model", "checked_at"),
+ index.Fields("checked_at"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_request_template.go b/backend/ent/schema/channel_monitor_request_template.go
new file mode 100644
index 00000000..59df2f29
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_request_template.go
@@ -0,0 +1,80 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorRequestTemplate 请求模板:一组可复用的 headers + 可选 body 覆盖配置。
+//
+// 语义为快照:模板被"应用"到监控时,extra_headers / body_override_mode / body_override
+// 会被**拷贝**到 channel_monitors 同名字段;后续模板变动不会自动影响已应用的监控——
+// 必须用户主动在模板编辑 Dialog 里点「应用到关联监控」才会覆盖快照。
+// 这样模板改错不会瞬间打挂所有已经跑起来的监控。
+type ChannelMonitorRequestTemplate struct {
+ ent.Schema
+}
+
+func (ChannelMonitorRequestTemplate) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_request_templates"},
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("name").
+ NotEmpty().
+ MaxLen(100),
+ field.Enum("provider").
+ Values("openai", "anthropic", "gemini"),
+ field.String("description").
+ Optional().
+ Default("").
+ MaxLen(500),
+ // extra_headers: 用户自定义 HTTP 头(如 User-Agent 伪装)。
+ // 运行时 merge 进 adapter 默认 headers,用户值优先;
+ // hop-by-hop 黑名单(Host/Content-Length/...)由 checker 过滤。
+ field.JSON("extra_headers", map[string]string{}).
+ Default(map[string]string{}),
+ // body_override_mode: 'off' | 'merge' | 'replace'
+ // off - 用 adapter 默认 body(忽略 body_override)
+ // merge - adapter 默认 body 与 body_override 浅合并(body_override 优先,
+ // model/messages/contents 等关键字段在 checker 里走黑名单跳过)
+ // replace - 直接用 body_override 作为完整 body;此时跳过 challenge 校验,
+ // 改为 HTTP 2xx + 响应文本非空即视为可用
+ field.String("body_override_mode").
+ Default("off").
+ MaxLen(10),
+ // body_override: JSON 对象,根据 body_override_mode 使用。
+ // 用 map[string]any 以便前端传任意结构(含嵌套)。
+ field.JSON("body_override", map[string]any{}).
+ Optional(),
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitors", ChannelMonitor.Type).
+ Ref("request_template"),
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
+ return []ent.Index{
+ // 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
+ index.Fields("provider", "name").Unique(),
+ }
+}
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index d78a6898..11f38d66 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -145,6 +145,11 @@ func (Group) Fields() []ent.Field {
Default(domain.OpenAIMessagesDispatchModelConfig{}).
SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
+
+ // 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。
+ field.Int("rpm_limit").
+ Default(0).
+ Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"),
}
}
diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go
new file mode 100644
index 00000000..9fdd26fb
--- /dev/null
+++ b/backend/ent/schema/identity_adoption_decision.go
@@ -0,0 +1,70 @@
+package schema
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow.
+type IdentityAdoptionDecision struct {
+ ent.Schema
+}
+
+func (IdentityAdoptionDecision) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "identity_adoption_decisions"},
+ }
+}
+
+func (IdentityAdoptionDecision) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (IdentityAdoptionDecision) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("pending_auth_session_id"),
+ field.Int64("identity_id").
+ Optional().
+ Nillable(),
+ field.Bool("adopt_display_name").
+ Default(false),
+ field.Bool("adopt_avatar").
+ Default(false),
+ field.Time("decided_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (IdentityAdoptionDecision) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("pending_auth_session", PendingAuthSession.Type).
+ Ref("adoption_decision").
+ Field("pending_auth_session_id").
+ Required().
+ Unique(),
+ edge.From("identity", AuthIdentity.Type).
+ Ref("adoption_decisions").
+ Field("identity_id").
+ Unique(),
+ }
+}
+
+func (IdentityAdoptionDecision) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("pending_auth_session_id").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go
index a9576d2a..d25d1e5e 100644
--- a/backend/ent/schema/payment_order.go
+++ b/backend/ent/schema/payment_order.go
@@ -91,6 +91,13 @@ func (PaymentOrder) Fields() []ent.Field {
Optional().
Nillable().
MaxLen(64),
+ field.String("provider_key").
+ Optional().
+ Nillable().
+ MaxLen(30),
+ field.JSON("provider_snapshot", map[string]any{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// 状态
field.String("status").
@@ -178,7 +185,9 @@ func (PaymentOrder) Edges() []ent.Edge {
func (PaymentOrder) Indexes() []ent.Index {
return []ent.Index{
- index.Fields("out_trade_no"),
+ index.Fields("out_trade_no").
+ Unique().
+ Annotations(entsql.IndexWhere("out_trade_no <> ''")),
index.Fields("user_id"),
index.Fields("status"),
index.Fields("expires_at"),
diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go
new file mode 100644
index 00000000..7e95f085
--- /dev/null
+++ b/backend/ent/schema/pending_auth_session.go
@@ -0,0 +1,135 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var pendingAuthIntents = map[string]struct{}{
+ "login": {},
+ "bind_current_user": {},
+ "adopt_existing_user_by_email": {},
+}
+
+func validatePendingAuthIntent(value string) error {
+ if _, ok := pendingAuthIntents[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid pending auth intent %q", value)
+}
+
+// PendingAuthSession stores a short-lived post-auth decision session.
+type PendingAuthSession struct {
+ ent.Schema
+}
+
+func (PendingAuthSession) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "pending_auth_sessions"},
+ }
+}
+
+func (PendingAuthSession) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (PendingAuthSession) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("session_token").
+ MaxLen(255).
+ NotEmpty(),
+ field.String("intent").
+ MaxLen(40).
+ NotEmpty().
+ Validate(validatePendingAuthIntent),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Int64("target_user_id").
+ Optional().
+ Nillable(),
+ field.String("redirect_to").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("resolved_email").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("registration_password_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("upstream_identity_claims", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.JSON("local_flow_state", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.String("browser_session_key").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("completion_code_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("completion_code_expires_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("email_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("password_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("totp_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("expires_at").
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("consumed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PendingAuthSession) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("target_user", User.Type).
+ Ref("pending_auth_sessions").
+ Field("target_user_id").
+ Unique(),
+ edge.To("adoption_decision", IdentityAdoptionDecision.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)).
+ Unique(),
+ }
+}
+
+func (PendingAuthSession) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("session_token").Unique(),
+ index.Fields("target_user_id"),
+ index.Fields("expires_at"),
+ index.Fields("provider_type", "provider_key", "provider_subject"),
+ index.Fields("completion_code_hash"),
+ }
+}
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index ef52e985..83da5c32 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -1,6 +1,8 @@
package schema
import (
+ "fmt"
+
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -72,6 +74,24 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+ field.String("signup_source").
+ Validate(func(value string) error {
+ switch value {
+ case "email", "linuxdo", "wechat", "oidc":
+ return nil
+ default:
+ return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
+ }
+ }).
+ Default("email"),
+ field.Time("last_login_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("last_active_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
// 余额不足通知
field.Bool("balance_notify_enabled").
@@ -88,6 +108,10 @@ func (User) Fields() []ent.Field {
field.Float("total_recharged").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
+
+ // 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。
+ field.Int("rpm_limit").
+ Default(0),
}
}
@@ -104,6 +128,9 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
+ edge.To("auth_identities", AuthIdentity.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index bb3139d5..611028e9 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -24,18 +24,34 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
+ // ChannelMonitor is the client for interacting with the ChannelMonitor builders.
+ ChannelMonitor *ChannelMonitorClient
+ // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
+ ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
+ // ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
+ ChannelMonitorHistory *ChannelMonitorHistoryClient
+ // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
+ ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
// PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
PaymentAuditLog *PaymentAuditLogClient
// PaymentOrder is the client for interacting with the PaymentOrder builders.
PaymentOrder *PaymentOrderClient
// PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -202,12 +218,20 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
+ tx.AuthIdentity = NewAuthIdentityClient(tx.config)
+ tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
+ tx.ChannelMonitor = NewChannelMonitorClient(tx.config)
+ tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config)
+ tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config)
+ tx.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
+ tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
tx.PaymentOrder = NewPaymentOrderClient(tx.config)
tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
+ tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 9fa91f74..06670444 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,12 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // SignupSource holds the value of the "signup_source" field.
+ SignupSource string `json:"signup_source,omitempty"`
+ // LastLoginAt holds the value of the "last_login_at" field.
+ LastLoginAt *time.Time `json:"last_login_at,omitempty"`
+ // LastActiveAt holds the value of the "last_active_at" field.
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
@@ -55,6 +61,8 @@ type User struct {
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
// TotalRecharged holds the value of the "total_recharged" field.
TotalRecharged float64 `json:"total_recharged,omitempty"`
+ // RpmLimit holds the value of the "rpm_limit" field.
+ RpmLimit int `json:"rpm_limit,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -83,11 +91,15 @@ type UserEdges struct {
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
// PaymentOrders holds the value of the payment_orders edge.
PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
+ // AuthIdentities holds the value of the auth_identities edge.
+ AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
+ // PendingAuthSessions holds the value of the pending_auth_sessions edge.
+ PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [11]bool
+ loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -180,10 +192,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
return nil, &NotLoadedError{edge: "payment_orders"}
}
+// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
+ if e.loadedTypes[10] {
+ return e.AuthIdentities, nil
+ }
+ return nil, &NotLoadedError{edge: "auth_identities"}
+}
+
+// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
+ if e.loadedTypes[11] {
+ return e.PendingAuthSessions, nil
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_sessions"}
+}
+
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
- if e.loadedTypes[10] {
+ if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -198,11 +228,11 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64)
- case user.FieldID, user.FieldConcurrency:
+ case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -312,6 +342,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldSignupSource:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field signup_source", values[i])
+ } else if value.Valid {
+ _m.SignupSource = value.String
+ }
+ case user.FieldLastLoginAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
+ } else if value.Valid {
+ _m.LastLoginAt = new(time.Time)
+ *_m.LastLoginAt = value.Time
+ }
+ case user.FieldLastActiveAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
+ } else if value.Valid {
+ _m.LastActiveAt = new(time.Time)
+ *_m.LastActiveAt = value.Time
+ }
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
@@ -343,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.TotalRecharged = value.Float64
}
+ case user.FieldRpmLimit:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
+ } else if value.Valid {
+ _m.RpmLimit = int(value.Int64)
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -406,6 +462,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
return NewUserClient(_m.config).QueryPaymentOrders(_m)
}
+// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
+func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
+ return NewUserClient(_m.config).QueryAuthIdentities(_m)
+}
+
+// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
+func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
+}
+
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -482,6 +548,19 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
+ builder.WriteString("signup_source=")
+ builder.WriteString(_m.SignupSource)
+ builder.WriteString(", ")
+ if v := _m.LastLoginAt; v != nil {
+ builder.WriteString("last_login_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.LastActiveAt; v != nil {
+ builder.WriteString("last_active_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
@@ -498,6 +577,9 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("total_recharged=")
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
+ builder.WriteString(", ")
+ builder.WriteString("rpm_limit=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index d88a3a38..e11a8a32 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,12 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldSignupSource holds the string denoting the signup_source field in the database.
+ FieldSignupSource = "signup_source"
+ // FieldLastLoginAt holds the string denoting the last_login_at field in the database.
+ FieldLastLoginAt = "last_login_at"
+ // FieldLastActiveAt holds the string denoting the last_active_at field in the database.
+ FieldLastActiveAt = "last_active_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
@@ -53,6 +59,8 @@ const (
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
FieldTotalRecharged = "total_recharged"
+ // FieldRpmLimit holds the string denoting the rpm_limit field in the database.
+ FieldRpmLimit = "rpm_limit"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -73,6 +81,10 @@ const (
EdgePromoCodeUsages = "promo_code_usages"
// EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
EdgePaymentOrders = "payment_orders"
+ // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
+ EdgeAuthIdentities = "auth_identities"
+ // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
+ EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -145,6 +157,20 @@ const (
PaymentOrdersInverseTable = "payment_orders"
// PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
PaymentOrdersColumn = "user_id"
+ // AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
+ AuthIdentitiesTable = "auth_identities"
+ // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ AuthIdentitiesInverseTable = "auth_identities"
+ // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
+ AuthIdentitiesColumn = "user_id"
+ // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
+ PendingAuthSessionsTable = "pending_auth_sessions"
+ // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionsInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
+ PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -171,11 +197,15 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldSignupSource,
+ FieldLastLoginAt,
+ FieldLastActiveAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
FieldBalanceNotifyExtraEmails,
FieldTotalRecharged,
+ FieldRpmLimit,
}
var (
@@ -232,6 +262,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultSignupSource holds the default value on creation for the "signup_source" field.
+ DefaultSignupSource string
+ // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ SignupSourceValidator func(string) error
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
@@ -240,6 +274,8 @@ var (
DefaultBalanceNotifyExtraEmails string
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
DefaultTotalRecharged float64
+ // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
+ DefaultRpmLimit int
)
// OrderOption defines the ordering options for the User queries.
@@ -320,6 +356,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// BySignupSource orders the results by the signup_source field.
+func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
+}
+
+// ByLastLoginAt orders the results by the last_login_at field.
+func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
+}
+
+// ByLastActiveAt orders the results by the last_active_at field.
+func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
+}
+
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
@@ -345,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
}
+// ByRpmLimit orders the results by the rpm_limit field.
+func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -485,6 +541,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
+// ByAuthIdentitiesCount orders the results by auth_identities count.
+func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
+ }
+}
+
+// ByAuthIdentities orders the results by auth_identities terms.
+func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
+func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
+ }
+}
+
+// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
+func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -568,6 +652,20 @@ func newPaymentOrdersStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
)
}
+func newAuthIdentitiesStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+}
+func newPendingAuthSessionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 2788aa7a..05d3b35b 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
+func SignupSource(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
+func LastLoginAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
+func LastActiveAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -150,6 +165,11 @@ func TotalRecharged(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
}
+// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
+func RpmLimit(v int) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -885,6 +905,171 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
+func SignupSourceEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
+}
+
+// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
+func SignupSourceNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
+}
+
+// SignupSourceIn applies the In predicate on the "signup_source" field.
+func SignupSourceIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
+func SignupSourceNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
+}
+
+// SignupSourceGT applies the GT predicate on the "signup_source" field.
+func SignupSourceGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSignupSource, v))
+}
+
+// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
+func SignupSourceGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSignupSource, v))
+}
+
+// SignupSourceLT applies the LT predicate on the "signup_source" field.
+func SignupSourceLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSignupSource, v))
+}
+
+// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
+func SignupSourceLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSignupSource, v))
+}
+
+// SignupSourceContains applies the Contains predicate on the "signup_source" field.
+func SignupSourceContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldSignupSource, v))
+}
+
+// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
+func SignupSourceHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
+}
+
+// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
+func SignupSourceHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
+}
+
+// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
+func SignupSourceEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
+}
+
+// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
+func SignupSourceContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
+}
+
+// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
+func LastLoginAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
+func LastLoginAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIn applies the In predicate on the "last_login_at" field.
+func LastLoginAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
+func LastLoginAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
+func LastLoginAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
+func LastLoginAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
+func LastLoginAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
+func LastLoginAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
+func LastLoginAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
+}
+
+// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
+func LastLoginAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
+}
+
+// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
+func LastActiveAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
+func LastActiveAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIn applies the In predicate on the "last_active_at" field.
+func LastActiveAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
+func LastActiveAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
+func LastActiveAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
+func LastActiveAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
+func LastActiveAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
+func LastActiveAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
+func LastActiveAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
+}
+
+// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
+func LastActiveAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
+}
+
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
@@ -1115,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
}
+// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
+func RpmLimitEQ(v int) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
+func RpmLimitNEQ(v int) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitIn applies the In predicate on the "rpm_limit" field.
+func RpmLimitIn(vs ...int) predicate.User {
+ return predicate.User(sql.FieldIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
+func RpmLimitNotIn(vs ...int) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
+func RpmLimitGT(v int) predicate.User {
+ return predicate.User(sql.FieldGT(FieldRpmLimit, v))
+}
+
+// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
+func RpmLimitGTE(v int) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldRpmLimit, v))
+}
+
+// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
+func RpmLimitLT(v int) predicate.User {
+ return predicate.User(sql.FieldLT(FieldRpmLimit, v))
+}
+
+// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
+func RpmLimitLTE(v int) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldRpmLimit, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
@@ -1345,6 +1570,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
})
}
+// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
+func HasAuthIdentities() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
+func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newAuthIdentitiesStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
+func HasPendingAuthSessions() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
+func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newPendingAuthSessionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index fbc64f9c..b4161128 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetSignupSource sets the "signup_source" field.
+func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
+ _c.mutation.SetSignupSource(v)
+ return _c
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
+ if v != nil {
+ _c.SetSignupSource(*v)
+ }
+ return _c
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastLoginAt(v)
+ return _c
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastLoginAt(*v)
+ }
+ return _c
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastActiveAt(v)
+ return _c
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastActiveAt(*v)
+ }
+ return _c
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -281,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
return _c
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (_c *UserCreate) SetRpmLimit(v int) *UserCreate {
+ _c.mutation.SetRpmLimit(v)
+ return _c
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate {
+ if v != nil {
+ _c.SetRpmLimit(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -431,6 +489,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
return _c.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddAuthIdentityIDs(ids...)
+ return _c
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddPendingAuthSessionIDs(ids...)
+ return _c
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -510,6 +598,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ v := user.DefaultSignupSource
+ _c.mutation.SetSignupSource(v)
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
@@ -526,6 +618,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotalRecharged
_c.mutation.SetTotalRecharged(v)
}
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ v := user.DefaultRpmLimit
+ _c.mutation.SetRpmLimit(v)
+ }
return nil
}
@@ -589,6 +685,14 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
+ }
+ if v, ok := _c.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
@@ -601,6 +705,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotalRecharged(); !ok {
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
}
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)}
+ }
return nil
}
@@ -684,6 +791,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ _node.SignupSource = value
+ }
+ if value, ok := _c.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ _node.LastLoginAt = &value
+ }
+ if value, ok := _c.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ _node.LastActiveAt = &value
+ }
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
@@ -704,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
_node.TotalRecharged = value
}
+ if value, ok := _c.mutation.RpmLimit(); ok {
+ _spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
+ _node.RpmLimit = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -868,6 +991,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
@@ -1106,6 +1261,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
+ u.Set(user.FieldSignupSource, v)
+ return u
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
+ u.SetExcluded(user.FieldSignupSource)
+ return u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastLoginAt, v)
+ return u
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastLoginAt)
+ return u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
+ u.SetNull(user.FieldLastLoginAt)
+ return u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastActiveAt, v)
+ return u
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastActiveAt)
+ return u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
+ u.SetNull(user.FieldLastActiveAt)
+ return u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
@@ -1184,6 +1387,24 @@ func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
return u
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert {
+ u.Set(user.FieldRpmLimit, v)
+ return u
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *UserUpsert) UpdateRpmLimit() *UserUpsert {
+ u.SetExcluded(user.FieldRpmLimit)
+ return u
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert {
+ u.Add(user.FieldRpmLimit, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1446,6 +1667,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
@@ -1537,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
})
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateRpmLimit()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1965,6 +2263,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSignupSource(v)
+ })
+}
+
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSignupSource()
+ })
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastLoginAt(v)
+ })
+}
+
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastLoginAt()
+ })
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastLoginAt()
+ })
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
@@ -2056,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
})
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateRpmLimit()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go
index 113d87ac..f1ee5cfe 100644
--- a/backend/ent/user_query.go
+++ b/backend/ent/user_query.go
@@ -15,8 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -44,6 +46,8 @@ type UserQuery struct {
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
withPaymentOrders *PaymentOrderQuery
+ withAuthIdentities *AuthIdentityQuery
+ withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
return query
}
+// QueryAuthIdentities chains the current query on the "auth_identities" edge.
+func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
+func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery {
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
withPaymentOrders: _q.withPaymentOrders.Clone(),
+ withAuthIdentities: _q.withAuthIdentities.Clone(),
+ withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu
return _q
}
+// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
+// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAuthIdentities = query
+ return _q
+}
+
+// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSessions = query
+ return _q
+}
+
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
- loadedTypes = [11]bool{
+ loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
_q.withPaymentOrders != nil,
+ _q.withAuthIdentities != nil,
+ _q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
+ if query := _q.withAuthIdentities; query != nil {
+ if err := _q.loadAuthIdentities(ctx, query, nodes,
+ func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
+ func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withPendingAuthSessions; query != nil {
+ if err := _q.loadPendingAuthSessions(ctx, query, nodes,
+ func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
+ func(n *User, e *PendingAuthSession) {
+ n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ
}
return nil
}
+func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentity.FieldUserID)
+ }
+ query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
+ }
+ query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TargetUserID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 6b355247..f1d759ce 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -13,8 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -333,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
return _u
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate {
+ _u.mutation.AddRpmLimit(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -483,6 +560,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -698,6 +805,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -767,6 +916,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -836,6 +990,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -860,6 +1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1322,6 +1497,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -1548,6 +1813,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
+ _u.mutation.SetSignupSource(v)
+ return _u
+}
+
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetSignupSource(*v)
+ }
+ return _u
+}
+
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastLoginAt(v)
+ return _u
+}
+
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastLoginAt(*v)
+ }
+ return _u
+}
+
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
@@ -1638,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
return _u
}
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne {
+ _u.mutation.AddRpmLimit(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1788,6 +2128,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
return _u.AddPaymentOrderIDs(ids...)
}
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -2003,6 +2373,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne
return _u.RemovePaymentOrderIDs(ids...)
}
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -2085,6 +2497,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -2171,6 +2588,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
@@ -2195,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -2657,6 +3095,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/go.mod b/backend/go.mod
index 66b6cc25..982bf91b 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -39,10 +39,11 @@ require (
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
- golang.org/x/crypto v0.48.0
- golang.org/x/net v0.49.0
- golang.org/x/sync v0.19.0
- golang.org/x/term v0.40.0
+ golang.org/x/crypto v0.49.0
+ golang.org/x/image v0.39.0
+ golang.org/x/net v0.52.0
+ golang.org/x/sync v0.20.0
+ golang.org/x/term v0.41.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
@@ -172,10 +173,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
- golang.org/x/mod v0.32.0 // indirect
- golang.org/x/sys v0.41.0 // indirect
- golang.org/x/text v0.34.0 // indirect
- golang.org/x/tools v0.41.0 // indirect
+ golang.org/x/mod v0.34.0 // indirect
+ golang.org/x/sys v0.42.0 // indirect
+ golang.org/x/text v0.36.0 // indirect
+ golang.org/x/tools v0.43.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index 9312af63..0f366ee1 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -413,16 +413,18 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
-golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
-golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
+golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
-golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
-golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
-golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
-golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
-golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
-golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
+golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
+golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
+golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
+golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -432,16 +434,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
-golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
-golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
-golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
+golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
+golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
+golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
-golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
-golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
+golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
+golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index dd9a4e58..87263db0 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
+// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
+// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
+// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
+const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
+
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
@@ -65,6 +70,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
+ WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -185,26 +191,47 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
+type WeChatConnectConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ AppID string `mapstructure:"app_id"`
+ AppSecret string `mapstructure:"app_secret"`
+ OpenAppID string `mapstructure:"open_app_id"`
+ OpenAppSecret string `mapstructure:"open_app_secret"`
+ MPAppID string `mapstructure:"mp_app_id"`
+ MPAppSecret string `mapstructure:"mp_app_secret"`
+ MobileAppID string `mapstructure:"mobile_app_id"`
+ MobileAppSecret string `mapstructure:"mobile_app_secret"`
+ OpenEnabled bool `mapstructure:"open_enabled"`
+ MPEnabled bool `mapstructure:"mp_enabled"`
+ MobileEnabled bool `mapstructure:"mobile_enabled"`
+ Mode string `mapstructure:"mode"`
+ Scopes string `mapstructure:"scopes"`
+ RedirectURL string `mapstructure:"redirect_url"`
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
+}
+
type OIDCConnectConfig struct {
- Enabled bool `mapstructure:"enabled"`
- ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
- ClientID string `mapstructure:"client_id"`
- ClientSecret string `mapstructure:"client_secret"`
- IssuerURL string `mapstructure:"issuer_url"`
- DiscoveryURL string `mapstructure:"discovery_url"`
- AuthorizeURL string `mapstructure:"authorize_url"`
- TokenURL string `mapstructure:"token_url"`
- UserInfoURL string `mapstructure:"userinfo_url"`
- JWKSURL string `mapstructure:"jwks_url"`
- Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
- RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
- FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
- TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
- UsePKCE bool `mapstructure:"use_pkce"`
- ValidateIDToken bool `mapstructure:"validate_id_token"`
- AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
- ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
- RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
+ Enabled bool `mapstructure:"enabled"`
+ ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
+ ClientID string `mapstructure:"client_id"`
+ ClientSecret string `mapstructure:"client_secret"`
+ IssuerURL string `mapstructure:"issuer_url"`
+ DiscoveryURL string `mapstructure:"discovery_url"`
+ AuthorizeURL string `mapstructure:"authorize_url"`
+ TokenURL string `mapstructure:"token_url"`
+ UserInfoURL string `mapstructure:"userinfo_url"`
+ JWKSURL string `mapstructure:"jwks_url"`
+ Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
+ RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
+ TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
+ UsePKCE bool `mapstructure:"use_pkce"`
+ ValidateIDToken bool `mapstructure:"validate_id_token"`
+ UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
+ ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
+ AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
+ ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
+ RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
@@ -213,6 +240,225 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
+const (
+ defaultWeChatConnectMode = "open"
+ defaultWeChatConnectScopes = "snsapi_login"
+ defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback"
+)
+
+func firstNonEmptyString(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
+func normalizeWeChatConnectMode(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "mp":
+ return "mp"
+ case "mobile":
+ return "mobile"
+ default:
+ return defaultWeChatConnectMode
+ }
+}
+
+func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
+ mode = normalizeWeChatConnectMode(mode)
+ switch mode {
+ case "open":
+ if openEnabled {
+ return "open"
+ }
+ case "mp":
+ if mpEnabled {
+ return "mp"
+ }
+ case "mobile":
+ if mobileEnabled {
+ return "mobile"
+ }
+ }
+ switch {
+ case openEnabled:
+ return "open"
+ case mpEnabled:
+ return "mp"
+ case mobileEnabled:
+ return "mobile"
+ default:
+ return mode
+ }
+}
+
+func defaultWeChatConnectScopesForMode(mode string) string {
+ switch normalizeWeChatConnectMode(mode) {
+ case "mp":
+ return "snsapi_userinfo"
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func normalizeWeChatConnectScopes(raw, mode string) string {
+ switch normalizeWeChatConnectMode(mode) {
+ case "mp":
+ switch strings.TrimSpace(raw) {
+ case "snsapi_base":
+ return "snsapi_base"
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ default:
+ return defaultWeChatConnectScopesForMode(mode)
+ }
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
+ if viper.InConfig(configKey) {
+ return false
+ }
+ _, hasNewEnv := os.LookupEnv(envKey)
+ return !hasNewEnv
+}
+
+func hasExplicitConfigOrEnv(configKey, envKey string) bool {
+ if viper.InConfig(configKey) {
+ return true
+ }
+ _, ok := os.LookupEnv(envKey)
+ return ok
+}
+
+func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
+ if cfg == nil {
+ return
+ }
+
+ legacyOpenAppID := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
+ legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
+ if legacyOpenAppID != "" {
+ cfg.OpenAppID = legacyOpenAppID
+ }
+ }
+
+ legacyOpenAppSecret := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
+ legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
+ if legacyOpenAppSecret != "" {
+ cfg.OpenAppSecret = legacyOpenAppSecret
+ }
+ }
+
+ legacyMPAppID := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
+ legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
+ if legacyMPAppID != "" {
+ cfg.MPAppID = legacyMPAppID
+ }
+ }
+
+ legacyMPAppSecret := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
+ legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
+ if legacyMPAppSecret != "" {
+ cfg.MPAppSecret = legacyMPAppSecret
+ }
+ }
+
+ if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") {
+ if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" {
+ cfg.FrontendRedirectURL = legacyFrontend
+ }
+ }
+
+ hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != ""
+ hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != ""
+
+ if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) {
+ cfg.Enabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen {
+ cfg.OpenEnabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP {
+ cfg.MPEnabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") {
+ switch {
+ case hasLegacyMP && !hasLegacyOpen:
+ cfg.Mode = "mp"
+ case hasLegacyOpen:
+ cfg.Mode = "open"
+ }
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") {
+ switch {
+ case hasLegacyMP && !hasLegacyOpen:
+ cfg.Scopes = defaultWeChatConnectScopesForMode("mp")
+ case hasLegacyOpen:
+ cfg.Scopes = defaultWeChatConnectScopesForMode("open")
+ }
+ }
+}
+
+func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) {
+ if cfg == nil {
+ return
+ }
+
+ cfg.AppID = strings.TrimSpace(cfg.AppID)
+ cfg.AppSecret = strings.TrimSpace(cfg.AppSecret)
+ cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID)
+ cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret)
+ cfg.MPAppID = strings.TrimSpace(cfg.MPAppID)
+ cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret)
+ cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID)
+ cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret)
+ cfg.Mode = normalizeWeChatConnectMode(cfg.Mode)
+ cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL)
+ cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL)
+
+ cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID)
+ cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret)
+ cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID)
+ cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret)
+ cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID)
+ cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret)
+ cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID)
+ cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret)
+
+ if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled {
+ switch cfg.Mode {
+ case "mp":
+ cfg.MPEnabled = true
+ case "mobile":
+ cfg.MobileEnabled = true
+ default:
+ cfg.OpenEnabled = true
+ }
+ }
+ cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode)
+ cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode)
+ if cfg.FrontendRedirectURL == "" {
+ cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect
+ }
+}
+
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -1007,6 +1253,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
+ applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat)
+ normalizeWeChatConnectConfig(&cfg.WeChat)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
@@ -1024,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
+ cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
+ cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
@@ -1202,6 +1452,24 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
+ // WeChat Connect OAuth 登录
+ viper.SetDefault("wechat_connect.enabled", false)
+ viper.SetDefault("wechat_connect.app_id", "")
+ viper.SetDefault("wechat_connect.app_secret", "")
+ viper.SetDefault("wechat_connect.open_app_id", "")
+ viper.SetDefault("wechat_connect.open_app_secret", "")
+ viper.SetDefault("wechat_connect.mp_app_id", "")
+ viper.SetDefault("wechat_connect.mp_app_secret", "")
+ viper.SetDefault("wechat_connect.mobile_app_id", "")
+ viper.SetDefault("wechat_connect.mobile_app_secret", "")
+ viper.SetDefault("wechat_connect.open_enabled", false)
+ viper.SetDefault("wechat_connect.mp_enabled", false)
+ viper.SetDefault("wechat_connect.mobile_enabled", false)
+ viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode)
+ viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes)
+ viper.SetDefault("wechat_connect.redirect_url", "")
+ viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect)
+
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
@@ -1217,7 +1485,7 @@ func setDefaults() {
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
- viper.SetDefault("oidc_connect.use_pkce", false)
+ viper.SetDefault("oidc_connect.use_pkce", true)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
@@ -1407,7 +1675,7 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
+ viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
@@ -1629,9 +1897,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.LinuxDo.UsePKCE {
- return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
@@ -1662,6 +1927,45 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
+ if c.WeChat.Enabled {
+ weChat := c.WeChat
+ normalizeWeChatConnectConfig(&weChat)
+
+ if weChat.OpenEnabled {
+ if strings.TrimSpace(weChat.OpenAppID) == "" {
+ return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true")
+ }
+ if strings.TrimSpace(weChat.OpenAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true")
+ }
+ }
+ if weChat.MPEnabled {
+ if strings.TrimSpace(weChat.MPAppID) == "" {
+ return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true")
+ }
+ if strings.TrimSpace(weChat.MPAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true")
+ }
+ }
+ if weChat.MobileEnabled {
+ if strings.TrimSpace(weChat.MobileAppID) == "" {
+ return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true")
+ }
+ if strings.TrimSpace(weChat.MobileAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true")
+ }
+ }
+ if v := strings.TrimSpace(weChat.RedirectURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err)
+ }
+ warnIfInsecureURL("wechat_connect.redirect_url", v)
+ }
+ if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil {
+ return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err)
+ }
+ warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL)
+ }
if c.OIDC.Enabled {
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
@@ -1685,9 +1989,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.OIDC.UsePKCE {
- return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.OIDC.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index cf58316c..6ba86aa1 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -225,6 +225,52 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
+func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
+ t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback")
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.True(t, cfg.WeChat.Enabled)
+ require.True(t, cfg.WeChat.OpenEnabled)
+ require.True(t, cfg.WeChat.MPEnabled)
+ require.False(t, cfg.WeChat.MobileEnabled)
+ require.Equal(t, "open", cfg.WeChat.Mode)
+ require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID)
+ require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret)
+ require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID)
+ require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret)
+ require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL)
+}
+
+func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.True(t, cfg.OIDC.UsePKCE)
+ require.True(t, cfg.OIDC.ValidateIDToken)
+ require.False(t, cfg.OIDC.UsePKCEExplicit)
+ require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
+}
+
+func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
+ t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.False(t, cfg.OIDC.UsePKCE)
+ require.False(t, cfg.OIDC.ValidateIDToken)
+ require.True(t, cfg.OIDC.UsePKCEExplicit)
+ require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
+}
+
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -334,7 +380,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
- cfg.LinuxDo.UsePKCE = false
+ cfg.LinuxDo.UsePKCE = true
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
@@ -346,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
-func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
+func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
@@ -363,11 +409,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
- if err == nil {
- t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
- }
- if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
- t.Fatalf("Validate() expected use_pkce error, got: %v", err)
+ if err != nil {
+ t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err)
}
}
@@ -389,6 +432,7 @@ func TestValidateOIDCScopesMustContainOpenID(t *testing.T) {
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "profile email"
+ cfg.OIDC.UsePKCE = true
err = cfg.Validate()
if err == nil {
@@ -418,6 +462,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.ValidateIDToken = true
+ cfg.OIDC.UsePKCE = true
err = cfg.Validate()
if err != nil {
@@ -425,6 +470,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
}
}
+func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.OIDC.Enabled = true
+ cfg.OIDC.ClientID = "oidc-client"
+ cfg.OIDC.ClientSecret = "oidc-secret"
+ cfg.OIDC.IssuerURL = "https://issuer.example.com"
+ cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
+ cfg.OIDC.TokenURL = "https://issuer.example.com/token"
+ cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo"
+ cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
+ cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
+ cfg.OIDC.Scopes = "openid email profile"
+ cfg.OIDC.UsePKCE = false
+ cfg.OIDC.ValidateIDToken = false
+ cfg.OIDC.JWKSURL = ""
+ cfg.OIDC.AllowedSigningAlgs = ""
+
+ err = cfg.Validate()
+ if err != nil {
+ t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err)
+ }
+}
+
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -840,6 +914,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
+ cfg.LinuxDo.UsePKCE = true
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
@@ -990,6 +1065,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
+ c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
@@ -998,6 +1074,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
+ c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index a57f7067..27c543dd 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -26,11 +26,12 @@ const (
// Account type constants
const (
- AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
- AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
- AccountTypeAPIKey = "apikey" // API Key类型账号
- AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
- AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
+ AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
+ AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
+ AccountTypeAPIKey = "apikey" // API Key类型账号
+ AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游)
+ AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
+ AccountTypeServiceAccount = "service_account" // Google Service Account 类型账号(用于 Vertex AI)
)
// Redeem type constants
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 9883d007..2d00ccc6 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -98,7 +98,7 @@ type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Notes *string `json:"notes"`
Platform string `json:"platform" binding:"required"`
- Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock"`
+ Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials" binding:"required"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -117,7 +117,7 @@ type CreateAccountRequest struct {
type UpdateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
- Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock"`
+ Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock service_account"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
@@ -134,19 +134,29 @@ type UpdateAccountRequest struct {
// BulkUpdateAccountsRequest represents the payload for bulk editing accounts
type BulkUpdateAccountsRequest struct {
- AccountIDs []int64 `json:"account_ids" binding:"required,min=1"`
- Name string `json:"name"`
- ProxyID *int64 `json:"proxy_id"`
- Concurrency *int `json:"concurrency"`
- Priority *int `json:"priority"`
- RateMultiplier *float64 `json:"rate_multiplier"`
- LoadFactor *int `json:"load_factor"`
- Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
- Schedulable *bool `json:"schedulable"`
- GroupIDs *[]int64 `json:"group_ids"`
- Credentials map[string]any `json:"credentials"`
- Extra map[string]any `json:"extra"`
- ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
+ AccountIDs []int64 `json:"account_ids"`
+ Filters *BulkUpdateAccountFilters `json:"filters"`
+ Name string `json:"name"`
+ ProxyID *int64 `json:"proxy_id"`
+ Concurrency *int `json:"concurrency"`
+ Priority *int `json:"priority"`
+ RateMultiplier *float64 `json:"rate_multiplier"`
+ LoadFactor *int `json:"load_factor"`
+ Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
+ Schedulable *bool `json:"schedulable"`
+ GroupIDs *[]int64 `json:"group_ids"`
+ Credentials map[string]any `json:"credentials"`
+ Extra map[string]any `json:"extra"`
+ ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
+}
+
+type BulkUpdateAccountFilters struct {
+ Platform string `json:"platform"`
+ Type string `json:"type"`
+ Status string `json:"status"`
+ Group string `json:"group"`
+ Search string `json:"search"`
+ PrivacyMode string `json:"privacy_mode"`
}
// CheckMixedChannelRequest represents check mixed channel risk request
@@ -652,6 +662,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
+ Mode string `json:"mode"`
}
type SyncFromCRSRequest struct {
@@ -682,7 +693,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
- if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
+ if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil {
// Error already sent via SSE, just log
return
}
@@ -1368,6 +1379,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
+ if len(req.AccountIDs) == 0 && req.Filters == nil {
+ response.BadRequest(c, "account_ids or filters is required")
+ return
+ }
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
@@ -1393,6 +1408,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
result, err := h.adminService.BulkUpdateAccounts(c.Request.Context(), &service.BulkUpdateAccountsInput{
AccountIDs: req.AccountIDs,
+ Filters: toServiceBulkUpdateAccountFilters(req.Filters),
Name: req.Name,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
@@ -1428,6 +1444,20 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.Success(c, result)
}
+func toServiceBulkUpdateAccountFilters(filters *BulkUpdateAccountFilters) *service.BulkUpdateAccountFilters {
+ if filters == nil {
+ return nil
+ }
+ return &service.BulkUpdateAccountFilters{
+ Platform: filters.Platform,
+ Type: filters.Type,
+ Status: filters.Status,
+ Group: filters.Group,
+ Search: filters.Search,
+ PrivacyMode: filters.PrivacyMode,
+ }
+}
+
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
index 24ec5bcf..929dc240 100644
--- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go
+++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
@@ -196,3 +196,29 @@ func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}
+
+func TestBulkUpdateAcceptsFilterTargetRequest(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountMixedChannelRouter(adminSvc)
+
+ body, _ := json.Marshal(map[string]any{
+ "filters": map[string]any{
+ "platform": "openai",
+ "type": "oauth",
+ "status": "active",
+ "group": "12",
+ "privacy_mode": "blocked",
+ "search": "bulk-target",
+ },
+ "schedulable": true,
+ })
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, float64(0), resp["code"])
+}
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index cba3ae21..ddeaab02 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -23,6 +23,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
+ router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
@@ -75,8 +76,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ bindBody := map[string]any{
+ "provider_type": "wechat",
+ "provider_key": "wechat-main",
+ "provider_subject": "union-123",
+ "metadata": map[string]any{"source": "admin-repair"},
+ "channel": map[string]any{
+ "channel": "open",
+ "channel_app_id": "wx-open",
+ "channel_subject": "openid-123",
+ },
+ }
+ body, _ := json.Marshal(bindBody)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
- body, _ := json.Marshal(createBody)
+ body, _ = json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
@@ -113,6 +132,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
}
+func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
+ router, adminSvc := setupAdminRouter()
+
+ body, err := json.Marshal(map[string]any{
+ "provider_type": "oidc",
+ "provider_key": "https://issuer.example",
+ "provider_subject": "subject-123",
+ "issuer": "https://issuer.example",
+ "metadata": map[string]any{"report_id": 12},
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
+ require.NotNil(t, adminSvc.boundAuthIdentity)
+ require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
+ require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
+ require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
+ require.Nil(t, adminSvc.boundAuthIdentity.Channel)
+ require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
+}
+
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go
index 3833d32e..6df49154 100644
--- a/backend/internal/handler/admin/admin_helpers_test.go
+++ b/backend/internal/handler/admin/admin_helpers_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -222,3 +223,66 @@ func TestOpsWSHelpers(t *testing.T) {
require.True(t, isAddrInTrustedProxies(addr, prefixes))
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
}
+
+// TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier 验证 admin
+// 写入路径会把 ServiceTier 的空字符串/空白/大小写归一化为
+// service.OpenAIFastTierAny ("all"),避免落盘时 "" 与 "all" 双语义。
+func TestOpenAIFastPolicySettingsFromDTO_NormalizesServiceTier(t *testing.T) {
+ t.Run("nil input returns nil", func(t *testing.T) {
+ require.Nil(t, openaiFastPolicySettingsFromDTO(nil))
+ })
+
+ t.Run("empty service_tier becomes 'all'", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{{
+ ServiceTier: "",
+ Action: "filter",
+ Scope: "all",
+ }},
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.NotNil(t, out)
+ require.Len(t, out.Rules, 1)
+ require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
+ require.Equal(t, "all", out.Rules[0].ServiceTier)
+ })
+
+ t.Run("whitespace-only service_tier becomes 'all'", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{{
+ ServiceTier: " ",
+ Action: "pass",
+ Scope: "all",
+ }},
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.Equal(t, service.OpenAIFastTierAny, out.Rules[0].ServiceTier)
+ })
+
+ t.Run("uppercase service_tier is lowercased", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{{
+ ServiceTier: "PRIORITY",
+ Action: "filter",
+ Scope: "all",
+ }},
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
+ })
+
+ t.Run("non-empty values pass through (lowercased)", func(t *testing.T) {
+ in := &dto.OpenAIFastPolicySettings{
+ Rules: []dto.OpenAIFastPolicyRule{
+ {ServiceTier: "priority", Action: "filter", Scope: "all"},
+ {ServiceTier: "flex", Action: "block", Scope: "oauth"},
+ {ServiceTier: "all", Action: "pass", Scope: "apikey"},
+ },
+ }
+ out := openaiFastPolicySettingsFromDTO(in)
+ require.Len(t, out.Rules, 3)
+ require.Equal(t, service.OpenAIFastTierPriority, out.Rules[0].ServiceTier)
+ require.Equal(t, service.OpenAIFastTierFlex, out.Rules[1].ServiceTier)
+ require.Equal(t, service.OpenAIFastTierAny, out.Rules[2].ServiceTier)
+ })
+}
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 6d1ef1b6..b187b47f 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -17,6 +17,8 @@ type stubAdminService struct {
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
+ boundAuthIdentity *service.AdminBindAuthIdentityInput
+ boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
@@ -42,6 +44,14 @@ type stubAdminService struct {
sortOrder string
calls int
}
+ lastListUsers struct {
+ page int
+ pageSize int
+ filters service.UserListFilters
+ sortBy string
+ sortOrder string
+ calls int
+ }
lastListProxies struct {
protocol string
status string
@@ -127,6 +137,12 @@ func newStubAdminService() *stubAdminService {
}
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
+ s.lastListUsers.page = page
+ s.lastListUsers.pageSize = pageSize
+ s.lastListUsers.filters = filters
+ s.lastListUsers.sortBy = sortBy
+ s.lastListUsers.sortOrder = sortOrder
+ s.lastListUsers.calls++
return s.users, int64(len(s.users)), nil
}
@@ -167,6 +183,63 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
+func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
+ user, err := s.GetUser(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ return &service.UserRPMStatus{
+ UserRPMUsed: 0,
+ UserRPMLimit: user.RPMLimit,
+ }, nil
+}
+
+func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
+ s.boundAuthIdentityFor = userID
+ copied := input
+ if input.Metadata != nil {
+ copied.Metadata = map[string]any{}
+ for key, value := range input.Metadata {
+ copied.Metadata[key] = value
+ }
+ }
+ if input.Channel != nil {
+ channel := *input.Channel
+ if input.Channel.Metadata != nil {
+ channel.Metadata = map[string]any{}
+ for key, value := range input.Channel.Metadata {
+ channel.Metadata[key] = value
+ }
+ }
+ copied.Channel = &channel
+ }
+ s.boundAuthIdentity = &copied
+
+ now := time.Now().UTC()
+ result := &service.AdminBoundAuthIdentity{
+ UserID: userID,
+ ProviderType: input.ProviderType,
+ ProviderKey: input.ProviderKey,
+ ProviderSubject: input.ProviderSubject,
+ VerifiedAt: &now,
+ Issuer: input.Issuer,
+ Metadata: input.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if input.Channel != nil {
+ result.Channel = &service.AdminBoundAuthIdentityChannel{
+ Channel: input.Channel.Channel,
+ ChannelAppID: input.Channel.ChannelAppID,
+ ChannelSubject: input.Channel.ChannelSubject,
+ Metadata: input.Channel.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ }
+ return result, nil
+}
+
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
@@ -214,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil
}
+func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
+ return nil
+}
+
+func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
+ return nil
+}
+
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
s.lastListAccounts.platform = platform
s.lastListAccounts.accountType = accountType
@@ -484,6 +565,22 @@ func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return nil, service.ErrAPIKeyNotFound
}
+func (s *stubAdminService) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*service.APIKey, error) {
+ for i := range s.apiKeys {
+ if s.apiKeys[i].ID == keyID {
+ s.apiKeys[i].Usage5h = 0
+ s.apiKeys[i].Usage1d = 0
+ s.apiKeys[i].Usage7d = 0
+ s.apiKeys[i].Window5hStart = nil
+ s.apiKeys[i].Window1dStart = nil
+ s.apiKeys[i].Window7dStart = nil
+ k := s.apiKeys[i]
+ return &k, nil
+ }
+ }
+ return nil, service.ErrAPIKeyNotFound
+}
+
func (s *stubAdminService) ResetAccountQuota(ctx context.Context, id int64) error {
return nil
}
diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go
new file mode 100644
index 00000000..97e649ec
--- /dev/null
+++ b/backend/internal/handler/admin/affiliate_handler.go
@@ -0,0 +1,183 @@
+package admin
+
+import (
+ "strconv"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AffiliateHandler handles admin affiliate (邀请返利) management:
+// listing users with custom settings, updating per-user invite codes
+// and exclusive rebate rates, and batch operations.
+type AffiliateHandler struct {
+ affiliateService *service.AffiliateService
+ adminService service.AdminService
+}
+
+// NewAffiliateHandler creates a new admin affiliate handler.
+func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler {
+ return &AffiliateHandler{
+ affiliateService: affiliateService,
+ adminService: adminService,
+ }
+}
+
+// ListUsers returns paginated users with custom affiliate settings.
+// GET /api/v1/admin/affiliates/users
+func (h *AffiliateHandler) ListUsers(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ search := c.Query("search")
+
+ entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{
+ Search: search,
+ Page: page,
+ PageSize: pageSize,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Paginated(c, entries, total, page, pageSize)
+}
+
+// UpdateUserSettings updates a user's affiliate settings.
+// PUT /api/v1/admin/affiliates/users/:user_id
+//
+// Both fields are optional and applied independently.
+type UpdateAffiliateUserRequest struct {
+ AffCode *string `json:"aff_code"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
+ // ClearRebateRate explicitly clears the per-user rate (sets it to NULL).
+ // Used to disambiguate from "field not provided".
+ ClearRebateRate bool `json:"clear_rebate_rate"`
+}
+
+func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
+ if err != nil || userID <= 0 {
+ response.BadRequest(c, "Invalid user_id")
+ return
+ }
+
+ var req UpdateAffiliateUserRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if req.AffCode != nil {
+ if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ if req.ClearRebateRate {
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ } else if req.AffRebateRatePercent != nil {
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ response.Success(c, gin.H{"user_id": userID})
+}
+
+// ClearUserSettings removes ALL of a user's custom affiliate settings — clears
+// the exclusive rebate rate AND regenerates the invite code as a new system
+// random one. Conceptually this "removes the user from the custom list".
+//
+// Both writes happen in this handler; failure of one leaves the other applied,
+// but the operation is idempotent so the admin can re-run it safely.
+// DELETE /api/v1/admin/affiliates/users/:user_id
+func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
+ if err != nil || userID <= 0 {
+ response.BadRequest(c, "Invalid user_id")
+ return
+ }
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"user_id": userID})
+}
+
+// BatchSetRate applies the same rebate rate (or clears it) to multiple users.
+//
+// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is
+// ignored). Otherwise aff_rebate_rate_percent is required and applied to
+// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal
+// can't distinguish a missing field from `null`, and a silent clear from a
+// frontend that forgot to include the rate would be a footgun.
+//
+// POST /api/v1/admin/affiliates/users/batch-rate
+type BatchSetRateRequest struct {
+ UserIDs []int64 `json:"user_ids" binding:"required"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
+ Clear bool `json:"clear"`
+}
+
+func (h *AffiliateHandler) BatchSetRate(c *gin.Context) {
+ var req BatchSetRateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if len(req.UserIDs) == 0 {
+ response.BadRequest(c, "user_ids cannot be empty")
+ return
+ }
+ if !req.Clear && req.AffRebateRatePercent == nil {
+ response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true")
+ return
+ }
+ rate := req.AffRebateRatePercent
+ if req.Clear {
+ rate = nil
+ }
+ if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"affected": len(req.UserIDs)})
+}
+
+// AffiliateUserSummary is the minimal user shape returned by LookupUsers,
+// shared with the frontend's add-custom-user picker.
+type AffiliateUserSummary struct {
+ ID int64 `json:"id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+}
+
+// LookupUsers searches users by email/username for the "add custom user" modal.
+// GET /api/v1/admin/affiliates/users/lookup?q=
+func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
+ keyword := c.Query("q")
+ if keyword == "" {
+ response.Success(c, []AffiliateUserSummary{})
+ return
+ }
+ users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc")
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ result := make([]AffiliateUserSummary, len(users))
+ for i, u := range users {
+ result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username}
+ }
+ response.Success(c, result)
+}
diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go
index 8dd245a4..5e405bdd 100644
--- a/backend/internal/handler/admin/apikey_handler.go
+++ b/backend/internal/handler/admin/apikey_handler.go
@@ -22,12 +22,13 @@ func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandle
}
}
-// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
+// AdminUpdateAPIKeyGroupRequest represents the request to update an API key.
type AdminUpdateAPIKeyGroupRequest struct {
- GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+ GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+ ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // true=重置 5h/1d/7d 限速用量
}
-// UpdateGroup handles updating an API key's group binding
+// UpdateGroup handles updating an API key's admin-managed fields.
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
@@ -42,11 +43,23 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
return
}
+ var resetKey *service.APIKey
+ if req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage {
+ resetKey, err = h.adminService.AdminResetAPIKeyRateLimitUsage(c.Request.Context(), keyID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
+ if resetKey != nil && req.GroupID == nil {
+ result.APIKey = resetKey
+ }
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go
index bf128b18..6ac6d52f 100644
--- a/backend/internal/handler/admin/apikey_handler_test.go
+++ b/backend/internal/handler/admin/apikey_handler_test.go
@@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -117,6 +118,45 @@ func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
require.Nil(t, resp.Data.APIKey.GroupID)
}
+func TestAdminAPIKeyHandler_ResetRateLimitUsage(t *testing.T) {
+ svc := newStubAdminService()
+ now := time.Now()
+ svc.apiKeys[0].Usage5h = 1.2
+ svc.apiKeys[0].Usage1d = 3.4
+ svc.apiKeys[0].Usage7d = 5.6
+ svc.apiKeys[0].Window5hStart = &now
+ svc.apiKeys[0].Window1dStart = &now
+ svc.apiKeys[0].Window7dStart = &now
+ router := setupAPIKeyHandler(svc)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"reset_rate_limit_usage":true}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Data struct {
+ APIKey struct {
+ Usage5h float64 `json:"usage_5h"`
+ Usage1d float64 `json:"usage_1d"`
+ Usage7d float64 `json:"usage_7d"`
+ Window5hStart *time.Time `json:"window_5h_start"`
+ Window1dStart *time.Time `json:"window_1d_start"`
+ Window7dStart *time.Time `json:"window_7d_start"`
+ } `json:"api_key"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Zero(t, resp.Data.APIKey.Usage5h)
+ require.Zero(t, resp.Data.APIKey.Usage1d)
+ require.Zero(t, resp.Data.APIKey.Usage7d)
+ require.Nil(t, resp.Data.APIKey.Window5hStart)
+ require.Nil(t, resp.Data.APIKey.Window1dStart)
+ require.Nil(t, resp.Data.APIKey.Window7dStart)
+}
+
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index 9151d018..950e6e72 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -158,9 +158,6 @@ func channelToResponse(ch *service.Channel) *channelResponse {
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
- if resp.BillingModelSource == "" {
- resp.BillingModelSource = service.BillingModelSourceChannelMapped
- }
if resp.GroupIDs == nil {
resp.GroupIDs = []int64{}
}
diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go
index f218cce4..12cd4bdd 100644
--- a/backend/internal/handler/admin/channel_handler_test.go
+++ b/backend/internal/handler/admin/channel_handler_test.go
@@ -91,7 +91,7 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
ch := &service.Channel{
ID: 1,
Name: "ch",
- BillingModelSource: "",
+ BillingModelSource: service.BillingModelSourceChannelMapped,
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
@@ -105,6 +105,9 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
},
}
+ // handler 层 channelToResponse 现在是纯透传:BillingModelSource 的空值兜底
+ // 已下放到 service 层(Create/GetByID/List/Update/ListAvailable 出口统一处理),
+ // 因此这里构造 fixture 时直接传入归一化后的值。
resp := channelToResponse(ch)
require.Equal(t, "channel_mapped", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
@@ -117,6 +120,19 @@ func TestChannelToResponse_EmptyDefaults(t *testing.T) {
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
+func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) {
+ // handler 不再兜底 BillingModelSource:空值应原样透传(由 service 层负责默认回填)。
+ ch := &service.Channel{
+ ID: 1,
+ Name: "ch",
+ BillingModelSource: "",
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ }
+ resp := channelToResponse(ch)
+ require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责")
+}
+
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
diff --git a/backend/internal/handler/admin/channel_monitor_handler.go b/backend/internal/handler/admin/channel_monitor_handler.go
new file mode 100644
index 00000000..e92c81fe
--- /dev/null
+++ b/backend/internal/handler/admin/channel_monitor_handler.go
@@ -0,0 +1,427 @@
+package admin
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ // monitorMaxPageSize 列表分页上限。
+ monitorMaxPageSize = 100
+ // monitorAPIKeyMaskPrefix 脱敏时保留的明文前缀长度。
+ monitorAPIKeyMaskPrefix = 4
+ // monitorAPIKeyMaskSuffix 脱敏后追加的占位字符串。
+ monitorAPIKeyMaskSuffix = "***"
+)
+
+// ChannelMonitorHandler 渠道监控管理后台 handler。
+type ChannelMonitorHandler struct {
+ monitorService *service.ChannelMonitorService
+}
+
+// NewChannelMonitorHandler 创建 handler。
+func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *ChannelMonitorHandler {
+ return &ChannelMonitorHandler{monitorService: monitorService}
+}
+
+// --- Request / Response ---
+
+type channelMonitorCreateRequest struct {
+ Name string `json:"name" binding:"required,max=100"`
+ Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ Endpoint string `json:"endpoint" binding:"required,max=500"`
+ APIKey string `json:"api_key" binding:"required,max=2000"`
+ PrimaryModel string `json:"primary_model" binding:"required,max=200"`
+ ExtraModels []string `json:"extra_models"`
+ GroupName string `json:"group_name" binding:"max=100"`
+ Enabled *bool `json:"enabled"`
+ IntervalSeconds int `json:"interval_seconds" binding:"required,min=15,max=3600"`
+ TemplateID *int64 `json:"template_id"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride map[string]any `json:"body_override"`
+}
+
+type channelMonitorUpdateRequest struct {
+ Name *string `json:"name" binding:"omitempty,max=100"`
+ Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
+ Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
+ APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
+ PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
+ ExtraModels *[]string `json:"extra_models"`
+ GroupName *string `json:"group_name" binding:"omitempty,max=100"`
+ Enabled *bool `json:"enabled"`
+ IntervalSeconds *int `json:"interval_seconds" binding:"omitempty,min=15,max=3600"`
+ TemplateID *int64 `json:"template_id"`
+ ClearTemplate bool `json:"clear_template"` // true 时把 template_id 置空,忽略 TemplateID
+ ExtraHeaders *map[string]string `json:"extra_headers"`
+ BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride *map[string]any `json:"body_override"`
+}
+
+type channelMonitorResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Endpoint string `json:"endpoint"`
+ APIKeyMasked string `json:"api_key_masked"`
+ APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"`
+ PrimaryModel string `json:"primary_model"`
+ ExtraModels []string `json:"extra_models"`
+ GroupName string `json:"group_name"`
+ Enabled bool `json:"enabled"`
+ IntervalSeconds int `json:"interval_seconds"`
+ LastCheckedAt *string `json:"last_checked_at"`
+ CreatedBy int64 `json:"created_by"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
+ PrimaryStatus string `json:"primary_status"`
+ PrimaryLatencyMs *int `json:"primary_latency_ms"`
+ Availability7d float64 `json:"availability_7d"`
+ ExtraModelsStatus []dto.ChannelMonitorExtraModelStatus `json:"extra_models_status"`
+ // 请求自定义快照:前端编辑 / 展示「高级设置」用
+ TemplateID *int64 `json:"template_id"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode"`
+ BodyOverride map[string]any `json:"body_override"`
+}
+
+type channelMonitorCheckResultResponse struct {
+ Model string `json:"model"`
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+ PingLatencyMs *int `json:"ping_latency_ms"`
+ Message string `json:"message"`
+ CheckedAt string `json:"checked_at"`
+}
+
+type channelMonitorHistoryItemResponse struct {
+ ID int64 `json:"id"`
+ Model string `json:"model"`
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+ PingLatencyMs *int `json:"ping_latency_ms"`
+ Message string `json:"message"`
+ CheckedAt string `json:"checked_at"`
+}
+
+// maskAPIKey 对 API Key 明文做脱敏:前 4 字符 + "***",长度 ≤ 4 时只显示 "***"。
+func maskAPIKey(plain string) string {
+ if len(plain) <= monitorAPIKeyMaskPrefix {
+ return monitorAPIKeyMaskSuffix
+ }
+ return plain[:monitorAPIKeyMaskPrefix] + monitorAPIKeyMaskSuffix
+}
+
+func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse {
+ if m == nil {
+ return nil
+ }
+ extras := m.ExtraModels
+ if extras == nil {
+ extras = []string{}
+ }
+ headers := m.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ resp := &channelMonitorResponse{
+ ID: m.ID,
+ Name: m.Name,
+ Provider: m.Provider,
+ Endpoint: m.Endpoint,
+ APIKeyMasked: maskAPIKey(m.APIKey),
+ APIKeyDecryptFailed: m.APIKeyDecryptFailed,
+ PrimaryModel: m.PrimaryModel,
+ ExtraModels: extras,
+ GroupName: m.GroupName,
+ Enabled: m.Enabled,
+ IntervalSeconds: m.IntervalSeconds,
+ CreatedBy: m.CreatedBy,
+ CreatedAt: m.CreatedAt.UTC().Format(time.RFC3339),
+ UpdatedAt: m.UpdatedAt.UTC().Format(time.RFC3339),
+ TemplateID: m.TemplateID,
+ ExtraHeaders: headers,
+ BodyOverrideMode: m.BodyOverrideMode,
+ BodyOverride: m.BodyOverride,
+ // PrimaryStatus / PrimaryLatencyMs / Availability7d 由 List handler 在批量聚合后填充。
+ }
+ if m.LastCheckedAt != nil {
+ s := m.LastCheckedAt.UTC().Format(time.RFC3339)
+ resp.LastCheckedAt = &s
+ }
+ return resp
+}
+
+func checkResultToResponse(r *service.CheckResult) channelMonitorCheckResultResponse {
+ return channelMonitorCheckResultResponse{
+ Model: r.Model,
+ Status: r.Status,
+ LatencyMs: r.LatencyMs,
+ PingLatencyMs: r.PingLatencyMs,
+ Message: r.Message,
+ CheckedAt: r.CheckedAt.UTC().Format(time.RFC3339),
+ }
+}
+
+func historyEntryToResponse(e *service.ChannelMonitorHistoryEntry) channelMonitorHistoryItemResponse {
+ return channelMonitorHistoryItemResponse{
+ ID: e.ID,
+ Model: e.Model,
+ Status: e.Status,
+ LatencyMs: e.LatencyMs,
+ PingLatencyMs: e.PingLatencyMs,
+ Message: e.Message,
+ CheckedAt: e.CheckedAt.UTC().Format(time.RFC3339),
+ }
+}
+
+// ParseChannelMonitorID 提取并校验路径参数 :id(admin 与 user handler 共享)。
+// 校验失败时已写入 4xx 响应,调用方只需 return。
+func ParseChannelMonitorID(c *gin.Context) (int64, bool) {
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || id <= 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("INVALID_MONITOR_ID", "invalid monitor id"))
+ return 0, false
+ }
+ return id, true
+}
+
+// parseListEnabled 解析 enabled query 参数:true/false 转为 *bool,空或非法则返回 nil。
+func parseListEnabled(raw string) *bool {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "true", "1", "yes":
+ v := true
+ return &v
+ case "false", "0", "no":
+ v := false
+ return &v
+ default:
+ return nil
+ }
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/admin/channel-monitors
+func (h *ChannelMonitorHandler) List(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ if pageSize > monitorMaxPageSize {
+ pageSize = monitorMaxPageSize
+ }
+
+ params := service.ChannelMonitorListParams{
+ Page: page,
+ PageSize: pageSize,
+ Provider: strings.TrimSpace(c.Query("provider")),
+ Enabled: parseListEnabled(c.Query("enabled")),
+ Search: strings.TrimSpace(c.Query("search")),
+ }
+
+ items, total, err := h.monitorService.List(c.Request.Context(), params)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ summaries := h.batchSummaryFor(c, items)
+ out := make([]*channelMonitorResponse, 0, len(items))
+ for _, m := range items {
+ out = append(out, buildListItemResponse(m, summaries[m.ID]))
+ }
+ response.Paginated(c, out, total, page, pageSize)
+}
+
+// batchSummaryFor 批量聚合 latest + 7d 可用率,避免每行 2 次 SQL(消除 N+1)。
+func (h *ChannelMonitorHandler) batchSummaryFor(c *gin.Context, items []*service.ChannelMonitor) map[int64]service.MonitorStatusSummary {
+ ids := make([]int64, 0, len(items))
+ primaryByID := make(map[int64]string, len(items))
+ extrasByID := make(map[int64][]string, len(items))
+ for _, m := range items {
+ ids = append(ids, m.ID)
+ primaryByID[m.ID] = m.PrimaryModel
+ extrasByID[m.ID] = m.ExtraModels
+ }
+ return h.monitorService.BatchMonitorStatusSummary(c.Request.Context(), ids, primaryByID, extrasByID)
+}
+
+// buildListItemResponse 把 monitor + summary 装成 admin list 的响应行。
+func buildListItemResponse(m *service.ChannelMonitor, summary service.MonitorStatusSummary) *channelMonitorResponse {
+ resp := channelMonitorToResponse(m)
+ resp.PrimaryStatus = summary.PrimaryStatus
+ resp.PrimaryLatencyMs = summary.PrimaryLatencyMs
+ resp.Availability7d = summary.Availability7d
+ resp.ExtraModelsStatus = make([]dto.ChannelMonitorExtraModelStatus, 0, len(summary.ExtraModels))
+ for _, e := range summary.ExtraModels {
+ resp.ExtraModelsStatus = append(resp.ExtraModelsStatus, dto.ChannelMonitorExtraModelStatus{
+ Model: e.Model,
+ Status: e.Status,
+ LatencyMs: e.LatencyMs,
+ })
+ }
+ return resp
+}
+
+// Get GET /api/v1/admin/channel-monitors/:id
+func (h *ChannelMonitorHandler) Get(c *gin.Context) {
+ id, ok := ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ m, err := h.monitorService.Get(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, channelMonitorToResponse(m))
+}
+
+// Create POST /api/v1/admin/channel-monitors
+func (h *ChannelMonitorHandler) Create(c *gin.Context) {
+ var req channelMonitorCreateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+
+ subject, _ := middleware2.GetAuthSubjectFromContext(c)
+
+ enabled := true
+ if req.Enabled != nil {
+ enabled = *req.Enabled
+ }
+
+ m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
+ Name: req.Name,
+ Provider: req.Provider,
+ Endpoint: req.Endpoint,
+ APIKey: req.APIKey,
+ PrimaryModel: req.PrimaryModel,
+ ExtraModels: req.ExtraModels,
+ GroupName: req.GroupName,
+ Enabled: enabled,
+ IntervalSeconds: req.IntervalSeconds,
+ CreatedBy: subject.UserID,
+ TemplateID: req.TemplateID,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Created(c, channelMonitorToResponse(m))
+}
+
+// Update PUT /api/v1/admin/channel-monitors/:id
+func (h *ChannelMonitorHandler) Update(c *gin.Context) {
+ id, ok := ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ var req channelMonitorUpdateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+
+ m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
+ Name: req.Name,
+ Provider: req.Provider,
+ Endpoint: req.Endpoint,
+ APIKey: req.APIKey,
+ PrimaryModel: req.PrimaryModel,
+ ExtraModels: req.ExtraModels,
+ GroupName: req.GroupName,
+ Enabled: req.Enabled,
+ IntervalSeconds: req.IntervalSeconds,
+ TemplateID: req.TemplateID,
+ ClearTemplate: req.ClearTemplate,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, channelMonitorToResponse(m))
+}
+
+// Delete DELETE /api/v1/admin/channel-monitors/:id
+func (h *ChannelMonitorHandler) Delete(c *gin.Context) {
+ id, ok := ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ if err := h.monitorService.Delete(c.Request.Context(), id); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, nil)
+}
+
+// Run POST /api/v1/admin/channel-monitors/:id/run
+func (h *ChannelMonitorHandler) Run(c *gin.Context) {
+ id, ok := ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ results, err := h.monitorService.RunCheck(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]channelMonitorCheckResultResponse, 0, len(results))
+ for _, r := range results {
+ out = append(out, checkResultToResponse(r))
+ }
+ response.Success(c, gin.H{"results": out})
+}
+
+// History GET /api/v1/admin/channel-monitors/:id/history
+func (h *ChannelMonitorHandler) History(c *gin.Context) {
+ id, ok := ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ limit := parseHistoryLimit(c.Query("limit"))
+ model := strings.TrimSpace(c.Query("model"))
+
+ entries, err := h.monitorService.ListHistory(c.Request.Context(), id, model, limit)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]channelMonitorHistoryItemResponse, 0, len(entries))
+ for _, e := range entries {
+ out = append(out, historyEntryToResponse(e))
+ }
+ response.Success(c, gin.H{"items": out})
+}
+
+// parseHistoryLimit 解析 history 接口的 limit query。
+// 使用 service 包的统一上下限常量,避免在 handler 重复定义同名魔法值。
+func parseHistoryLimit(raw string) int {
+ if strings.TrimSpace(raw) == "" {
+ return service.MonitorHistoryDefaultLimit
+ }
+ v, err := strconv.Atoi(raw)
+ if err != nil || v <= 0 {
+ return service.MonitorHistoryDefaultLimit
+ }
+ if v > service.MonitorHistoryMaxLimit {
+ return service.MonitorHistoryMaxLimit
+ }
+ return v
+}
diff --git a/backend/internal/handler/admin/channel_monitor_template_handler.go b/backend/internal/handler/admin/channel_monitor_template_handler.go
new file mode 100644
index 00000000..bebe0929
--- /dev/null
+++ b/backend/internal/handler/admin/channel_monitor_template_handler.go
@@ -0,0 +1,234 @@
+package admin
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ChannelMonitorRequestTemplateHandler 请求模板管理后台 handler。
+type ChannelMonitorRequestTemplateHandler struct {
+ templateService *service.ChannelMonitorRequestTemplateService
+}
+
+// NewChannelMonitorRequestTemplateHandler 创建 handler。
+func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMonitorRequestTemplateService) *ChannelMonitorRequestTemplateHandler {
+ return &ChannelMonitorRequestTemplateHandler{templateService: templateService}
+}
+
+// --- DTO ---
+
+type channelMonitorTemplateCreateRequest struct {
+ Name string `json:"name" binding:"required,max=100"`
+ Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ Description string `json:"description" binding:"max=500"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride map[string]any `json:"body_override"`
+}
+
+type channelMonitorTemplateUpdateRequest struct {
+ Name *string `json:"name" binding:"omitempty,max=100"`
+ Description *string `json:"description" binding:"omitempty,max=500"`
+ ExtraHeaders *map[string]string `json:"extra_headers"`
+ BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride *map[string]any `json:"body_override"`
+}
+
+type channelMonitorTemplateResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Description string `json:"description"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode"`
+ BodyOverride map[string]any `json:"body_override"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
+ AssociatedMonitors int64 `json:"associated_monitors"`
+}
+
+func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *service.ChannelMonitorRequestTemplate) *channelMonitorTemplateResponse {
+ if t == nil {
+ return nil
+ }
+ headers := t.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ count, _ := h.templateService.CountAssociatedMonitors(c.Request.Context(), t.ID)
+ return &channelMonitorTemplateResponse{
+ ID: t.ID,
+ Name: t.Name,
+ Provider: t.Provider,
+ Description: t.Description,
+ ExtraHeaders: headers,
+ BodyOverrideMode: t.BodyOverrideMode,
+ BodyOverride: t.BodyOverride,
+ CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
+ UpdatedAt: t.UpdatedAt.UTC().Format(time.RFC3339),
+ AssociatedMonitors: count,
+ }
+}
+
+// parseTemplateID 提取并校验 :id。
+func parseTemplateID(c *gin.Context) (int64, bool) {
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || id <= 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("INVALID_TEMPLATE_ID", "invalid template id"))
+ return 0, false
+ }
+ return id, true
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/admin/channel-monitor-templates?provider=anthropic
+func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
+ items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
+ Provider: strings.TrimSpace(c.Query("provider")),
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]*channelMonitorTemplateResponse, 0, len(items))
+ for _, t := range items {
+ out = append(out, h.toResponse(c, t))
+ }
+ response.Success(c, gin.H{"items": out})
+}
+
+// Get GET /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Get(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ t, err := h.templateService.Get(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, h.toResponse(c, t))
+}
+
+// Create POST /api/v1/admin/channel-monitor-templates
+func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
+ var req channelMonitorTemplateCreateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
+ Name: req.Name,
+ Provider: req.Provider,
+ Description: req.Description,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Created(c, h.toResponse(c, t))
+}
+
+// Update PUT /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ var req channelMonitorTemplateUpdateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
+ Name: req.Name,
+ Description: req.Description,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, h.toResponse(c, t))
+}
+
+// Delete DELETE /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ if err := h.templateService.Delete(c.Request.Context(), id); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, nil)
+}
+
+type channelMonitorTemplateApplyRequest struct {
+ // MonitorIDs 必填、非空:用户在 picker 里勾选的要被覆盖的监控 ID 列表。
+ // 仅当对应监控当前 template_id == :id 时才会真的被覆盖。
+ MonitorIDs []int64 `json:"monitor_ids" binding:"required,min=1"`
+}
+
+// Apply POST /api/v1/admin/channel-monitor-templates/:id/apply
+// 把模板当前配置覆盖到 monitor_ids 列表里的关联监控(picker 选中的子集)。
+func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ var req channelMonitorTemplateApplyRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id, req.MonitorIDs)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"affected": affected})
+}
+
+type associatedMonitorBriefResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Enabled bool `json:"enabled"`
+}
+
+// AssociatedMonitors GET /api/v1/admin/channel-monitor-templates/:id/monitors
+// 列出关联监控(picker 弹窗用)。
+func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ items, err := h.templateService.ListAssociatedMonitors(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]associatedMonitorBriefResponse, 0, len(items))
+ for _, m := range items {
+ out = append(out, associatedMonitorBriefResponse{
+ ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
+ })
+ }
+ response.Success(c, gin.H{"items": out})
+}
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index cb2bd201..65e5ec78 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -110,6 +110,8 @@ type CreateGroupRequest struct {
RequirePrivacySet bool `json:"require_privacy_set"`
DefaultMappedModel string `json:"default_mapped_model"`
MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
+ // 分组 RPM 上限(0 = 不限制)
+ RPMLimit int `json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -145,6 +147,8 @@ type UpdateGroupRequest struct {
RequirePrivacySet *bool `json:"require_privacy_set"`
DefaultMappedModel *string `json:"default_mapped_model"`
MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
+ // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
+ RPMLimit *int `json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
+ RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
+ RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
+// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
+type BatchSetGroupRPMOverridesRequest struct {
+ Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
+}
+
+// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
+// PUT /api/v1/admin/groups/:id/rpm-overrides
+func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
+ groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid group ID")
+ return
+ }
+
+ var req BatchSetGroupRPMOverridesRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
+}
+
+// ClearGroupRPMOverrides handles clearing all rpm_override for a group
+// DELETE /api/v1/admin/groups/:id/rpm-overrides
+func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
+ groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid group ID")
+ return
+ }
+
+ if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
+}
+
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go
index b0ed6aed..84359cd9 100644
--- a/backend/internal/handler/admin/payment_handler.go
+++ b/backend/internal/handler/admin/payment_handler.go
@@ -3,6 +3,7 @@ package admin
import (
"strconv"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrderDetail returns detailed information about a single order.
@@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
return
}
auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
- response.Success(c, gin.H{"order": order, "auditLogs": auditLogs})
+ response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs})
}
// CancelOrder cancels a pending order (admin).
@@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
response.Success(c, gin.H{"message": "fulfillment retried"})
}
+func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizeAdminPaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
// AdminProcessRefundRequest is the request body for admin refund processing.
type AdminProcessRefundRequest struct {
Amount float64 `json:"amount"`
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index bec0f126..59f4fe85 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -43,6 +43,15 @@ func scopesContainOpenID(scopes string) bool {
return false
}
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
@@ -73,6 +82,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
@@ -93,114 +107,191 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
- RegistrationEnabled: settings.RegistrationEnabled,
- EmailVerifyEnabled: settings.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: settings.PromoCodeEnabled,
- PasswordResetEnabled: settings.PasswordResetEnabled,
- FrontendURL: settings.FrontendURL,
- InvitationCodeEnabled: settings.InvitationCodeEnabled,
- TotpEnabled: settings.TotpEnabled,
- TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
- SMTPHost: settings.SMTPHost,
- SMTPPort: settings.SMTPPort,
- SMTPUsername: settings.SMTPUsername,
- SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
- SMTPFrom: settings.SMTPFrom,
- SMTPFromName: settings.SMTPFromName,
- SMTPUseTLS: settings.SMTPUseTLS,
- TurnstileEnabled: settings.TurnstileEnabled,
- TurnstileSiteKey: settings.TurnstileSiteKey,
- TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
- LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
- LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
- LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
- OIDCConnectEnabled: settings.OIDCConnectEnabled,
- OIDCConnectProviderName: settings.OIDCConnectProviderName,
- OIDCConnectClientID: settings.OIDCConnectClientID,
- OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
- OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
- OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
- OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
- OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
- OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
- OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
- OIDCConnectScopes: settings.OIDCConnectScopes,
- OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
- OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
- OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
- OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
- OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
- OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
- OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
- OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
- OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
- OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
- SiteName: settings.SiteName,
- SiteLogo: settings.SiteLogo,
- SiteSubtitle: settings.SiteSubtitle,
- APIBaseURL: settings.APIBaseURL,
- ContactInfo: settings.ContactInfo,
- DocURL: settings.DocURL,
- HomeContent: settings.HomeContent,
- HideCcsImportButton: settings.HideCcsImportButton,
- PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
- PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
- TableDefaultPageSize: settings.TableDefaultPageSize,
- TablePageSizeOptions: settings.TablePageSizeOptions,
- CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
- CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
- DefaultConcurrency: settings.DefaultConcurrency,
- DefaultBalance: settings.DefaultBalance,
- DefaultSubscriptions: defaultSubscriptions,
- EnableModelFallback: settings.EnableModelFallback,
- FallbackModelAnthropic: settings.FallbackModelAnthropic,
- FallbackModelOpenAI: settings.FallbackModelOpenAI,
- FallbackModelGemini: settings.FallbackModelGemini,
- FallbackModelAntigravity: settings.FallbackModelAntigravity,
- EnableIdentityPatch: settings.EnableIdentityPatch,
- IdentityPatchPrompt: settings.IdentityPatchPrompt,
- OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
- OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
- OpsQueryModeDefault: settings.OpsQueryModeDefault,
- OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
- BackendModeEnabled: settings.BackendModeEnabled,
- EnableFingerprintUnification: settings.EnableFingerprintUnification,
- EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
- EnableCCHSigning: settings.EnableCCHSigning,
- WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
- BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
- BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
- AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
- AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
- PaymentEnabled: paymentCfg.Enabled,
- PaymentMinAmount: paymentCfg.MinAmount,
- PaymentMaxAmount: paymentCfg.MaxAmount,
- PaymentDailyLimit: paymentCfg.DailyLimit,
- PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
- PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
- PaymentEnabledTypes: paymentCfg.EnabledTypes,
- PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
- PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
- PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
- PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
- PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
- PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
- PaymentHelpImageURL: paymentCfg.HelpImageURL,
- PaymentHelpText: paymentCfg.HelpText,
- PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
- PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
- PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
- PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
- PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
- })
+ payload := dto.SystemSettings{
+ RegistrationEnabled: settings.RegistrationEnabled,
+ EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ FrontendURL: settings.FrontendURL,
+ InvitationCodeEnabled: settings.InvitationCodeEnabled,
+ TotpEnabled: settings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
+ SMTPHost: settings.SMTPHost,
+ SMTPPort: settings.SMTPPort,
+ SMTPUsername: settings.SMTPUsername,
+ SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
+ SMTPFrom: settings.SMTPFrom,
+ SMTPFromName: settings.SMTPFromName,
+ SMTPUseTLS: settings.SMTPUseTLS,
+ TurnstileEnabled: settings.TurnstileEnabled,
+ TurnstileSiteKey: settings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
+ LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
+ LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: settings.WeChatConnectEnabled,
+ WeChatConnectAppID: settings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
+ WeChatConnectOpenAppID: settings.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecretConfigured: settings.WeChatConnectOpenAppSecretConfigured,
+ WeChatConnectMPAppID: settings.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecretConfigured: settings.WeChatConnectMPAppSecretConfigured,
+ WeChatConnectMobileAppID: settings.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecretConfigured: settings.WeChatConnectMobileAppSecretConfigured,
+ WeChatConnectOpenEnabled: settings.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: settings.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: settings.WeChatConnectMobileEnabled,
+ WeChatConnectMode: settings.WeChatConnectMode,
+ WeChatConnectScopes: settings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: settings.OIDCConnectEnabled,
+ OIDCConnectProviderName: settings.OIDCConnectProviderName,
+ OIDCConnectClientID: settings.OIDCConnectClientID,
+ OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
+ OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
+ OIDCConnectScopes: settings.OIDCConnectScopes,
+ OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
+ OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
+ SiteName: settings.SiteName,
+ SiteLogo: settings.SiteLogo,
+ SiteSubtitle: settings.SiteSubtitle,
+ APIBaseURL: settings.APIBaseURL,
+ ContactInfo: settings.ContactInfo,
+ DocURL: settings.DocURL,
+ HomeContent: settings.HomeContent,
+ HideCcsImportButton: settings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: settings.TableDefaultPageSize,
+ TablePageSizeOptions: settings.TablePageSizeOptions,
+ CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
+ CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
+ DefaultConcurrency: settings.DefaultConcurrency,
+ DefaultBalance: settings.DefaultBalance,
+ AffiliateRebateRate: settings.AffiliateRebateRate,
+ AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
+ DefaultSubscriptions: defaultSubscriptions,
+ EnableModelFallback: settings.EnableModelFallback,
+ FallbackModelAnthropic: settings.FallbackModelAnthropic,
+ FallbackModelOpenAI: settings.FallbackModelOpenAI,
+ FallbackModelGemini: settings.FallbackModelGemini,
+ FallbackModelAntigravity: settings.FallbackModelAntigravity,
+ EnableIdentityPatch: settings.EnableIdentityPatch,
+ IdentityPatchPrompt: settings.IdentityPatchPrompt,
+ OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
+ OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
+ OpsQueryModeDefault: settings.OpsQueryModeDefault,
+ OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: settings.BackendModeEnabled,
+ EnableFingerprintUnification: settings.EnableFingerprintUnification,
+ EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
+ EnableCCHSigning: settings.EnableCCHSigning,
+ EnableAnthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
+ PaymentEnabled: paymentCfg.Enabled,
+ PaymentMinAmount: paymentCfg.MinAmount,
+ PaymentMaxAmount: paymentCfg.MaxAmount,
+ PaymentDailyLimit: paymentCfg.DailyLimit,
+ PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
+ PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
+ PaymentEnabledTypes: paymentCfg.EnabledTypes,
+ PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
+ PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
+ PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
+ PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
+ PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
+ PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
+ PaymentHelpImageURL: paymentCfg.HelpImageURL,
+ PaymentHelpText: paymentCfg.HelpText,
+ PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
+ PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
+ PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
+ PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
+ PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
+
+ AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: settings.AffiliateEnabled,
+ }
+
+ // OpenAI fast policy (stored under a dedicated setting key)
+ if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
+ slog.Error("openai_fast_policy_settings_get_failed", "error", err)
+ } else if fastPolicy != nil {
+ payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
+ }
+
+ response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
+}
+
+// openaiFastPolicySettingsToDTO converts service -> dto for OpenAI fast policy.
+func openaiFastPolicySettingsToDTO(s *service.OpenAIFastPolicySettings) *dto.OpenAIFastPolicySettings {
+ if s == nil {
+ return nil
+ }
+ rules := make([]dto.OpenAIFastPolicyRule, len(s.Rules))
+ for i, r := range s.Rules {
+ rules[i] = dto.OpenAIFastPolicyRule(r)
+ }
+ return &dto.OpenAIFastPolicySettings{Rules: rules}
+}
+
+// openaiFastPolicySettingsFromDTO converts dto -> service for OpenAI fast policy.
+//
+// 规范化 ServiceTier:在 DTO 进入 service 层之前统一把空字符串归一为
+// service.OpenAIFastTierAny ("all"),避免管理员保存时空串与 "all" 同时
+// 表达"匹配任意 tier"造成数据库取值的二义性。其它非空值原样透传,由
+// service.SetOpenAIFastPolicySettings 负责合法值校验。
+func openaiFastPolicySettingsFromDTO(s *dto.OpenAIFastPolicySettings) *service.OpenAIFastPolicySettings {
+ if s == nil {
+ return nil
+ }
+ rules := make([]service.OpenAIFastPolicyRule, len(s.Rules))
+ for i, r := range s.Rules {
+ rules[i] = service.OpenAIFastPolicyRule(r)
+ tier := strings.ToLower(strings.TrimSpace(rules[i].ServiceTier))
+ if tier == "" {
+ tier = service.OpenAIFastTierAny
+ }
+ rules[i].ServiceTier = tier
+ }
+ return &service.OpenAIFastPolicySettings{Rules: rules}
}
// UpdateSettingsRequest 更新设置请求
@@ -235,6 +326,24 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecret string `json:"wechat_connect_app_secret"`
+ WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
+ WeChatConnectOpenAppSecret string `json:"wechat_connect_open_app_secret"`
+ WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
+ WeChatConnectMPAppSecret string `json:"wechat_connect_mp_app_secret"`
+ WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
+ WeChatConnectMobileAppSecret string `json:"wechat_connect_mobile_app_secret"`
+ WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
+ WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
+ WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
@@ -250,8 +359,8 @@ type UpdateSettingsRequest struct {
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
- OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
- OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
+ OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"`
+ OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
@@ -276,9 +385,35 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
+ DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
+ AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
+ AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
+ AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
+ AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
+ AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
+ AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
+ AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
+ AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
+ AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
+ AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
+ AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
+ AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
+ AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
+ AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
+ AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
+ AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
+ AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
+ AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
+ AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
+ ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -307,9 +442,19 @@ type UpdateSettingsRequest struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
// Gateway forwarding behavior
- EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
- EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
- EnableCCHSigning *bool `json:"enable_cch_signing"`
+ EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
+ EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
+ EnableCCHSigning *bool `json:"enable_cch_signing"`
+ EnableAnthropicCacheTTL1hInjection *bool `json:"enable_anthropic_cache_ttl_1h_injection"`
+
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
@@ -341,6 +486,19 @@ type UpdateSettingsRequest struct {
PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
+
+ // Channel Monitor feature switch
+ ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
+
+ // Available Channels feature switch (user-facing)
+ AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature switch
+ AffiliateEnabled *bool `json:"affiliate_enabled"`
+
+ // OpenAI fast/flex policy (optional, only updated when provided)
+ OpenAIFastPolicySettings *dto.OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
// UpdateSettings 更新系统设置
@@ -357,6 +515,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// 验证参数
if req.DefaultConcurrency < 1 {
@@ -365,6 +528,43 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
+ affiliateRebateRate := previousSettings.AffiliateRebateRate
+ if req.AffiliateRebateRate != nil {
+ affiliateRebateRate = *req.AffiliateRebateRate
+ }
+ if affiliateRebateRate < service.AffiliateRebateRateMin {
+ affiliateRebateRate = service.AffiliateRebateRateMin
+ }
+ if affiliateRebateRate > service.AffiliateRebateRateMax {
+ affiliateRebateRate = service.AffiliateRebateRateMax
+ }
+ affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
+ if req.AffiliateRebateFreezeHours != nil {
+ affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
+ }
+ if affiliateRebateFreezeHours < 0 {
+ affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
+ }
+ if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
+ affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
+ }
+ affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
+ if req.AffiliateRebateDurationDays != nil {
+ affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
+ }
+ if affiliateRebateDurationDays < 0 {
+ affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
+ }
+ if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
+ affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
+ }
+ affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
+ if req.AffiliateRebatePerInviteeCap != nil {
+ affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
+ }
+ if affiliateRebatePerInviteeCap < 0 {
+ affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
+ }
// 通用表格配置:兼容旧客户端未传字段时保留当前值。
if req.TableDefaultPageSize <= 0 {
req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
@@ -381,6 +581,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
+ req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
+ req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
+ req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
+ req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@@ -459,7 +663,141 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ if req.WeChatConnectEnabled {
+ req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
+ req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
+ req.WeChatConnectOpenAppID = strings.TrimSpace(req.WeChatConnectOpenAppID)
+ req.WeChatConnectOpenAppSecret = strings.TrimSpace(req.WeChatConnectOpenAppSecret)
+ req.WeChatConnectMPAppID = strings.TrimSpace(req.WeChatConnectMPAppID)
+ req.WeChatConnectMPAppSecret = strings.TrimSpace(req.WeChatConnectMPAppSecret)
+ req.WeChatConnectMobileAppID = strings.TrimSpace(req.WeChatConnectMobileAppID)
+ req.WeChatConnectMobileAppSecret = strings.TrimSpace(req.WeChatConnectMobileAppSecret)
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode))
+ req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
+ req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
+ req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL))
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL))
+ if req.WeChatConnectMode == "" {
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode))
+ }
+ if req.WeChatConnectScopes == "" {
+ req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes)
+ }
+
+ if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled {
+ response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time")
+ return
+ }
+ if req.WeChatConnectMode != "" {
+ switch req.WeChatConnectMode {
+ case "open", "mp", "mobile":
+ default:
+ response.BadRequest(c, "WeChat mode must be open, mp, or mobile")
+ return
+ }
+ }
+ if !req.WeChatConnectOpenEnabled && !req.WeChatConnectMPEnabled && !req.WeChatConnectMobileEnabled {
+ switch req.WeChatConnectMode {
+ case "mp":
+ req.WeChatConnectMPEnabled = true
+ case "mobile":
+ req.WeChatConnectMobileEnabled = true
+ default:
+ req.WeChatConnectOpenEnabled = true
+ }
+ }
+ if req.WeChatConnectMode == "" {
+ if req.WeChatConnectMPEnabled {
+ req.WeChatConnectMode = "mp"
+ } else if req.WeChatConnectMobileEnabled {
+ req.WeChatConnectMode = "mobile"
+ } else {
+ req.WeChatConnectMode = "open"
+ }
+ }
+
+ req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID))
+
+ if req.WeChatConnectOpenAppSecret == "" {
+ req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectMPAppSecret == "" {
+ req.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMPAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectMobileAppSecret == "" {
+ req.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectAppSecret == "" {
+ req.WeChatConnectAppSecret = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppSecret, req.WeChatConnectMPAppSecret, req.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret))
+ }
+
+ if req.WeChatConnectOpenEnabled {
+ if req.WeChatConnectOpenAppID == "" {
+ response.BadRequest(c, "WeChat PC App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectOpenAppSecret == "" {
+ response.BadRequest(c, "WeChat PC App Secret is required when enabled")
+ return
+ }
+ }
+ if req.WeChatConnectMPEnabled {
+ if req.WeChatConnectMPAppID == "" {
+ response.BadRequest(c, "WeChat Official Account App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectMPAppSecret == "" {
+ response.BadRequest(c, "WeChat Official Account App Secret is required when enabled")
+ return
+ }
+ }
+ if req.WeChatConnectMobileEnabled {
+ if req.WeChatConnectMobileAppID == "" {
+ response.BadRequest(c, "WeChat Mobile App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectMobileAppSecret == "" {
+ response.BadRequest(c, "WeChat Mobile App Secret is required when enabled")
+ return
+ }
+ }
+
+ if req.WeChatConnectScopes == "" {
+ if req.WeChatConnectMPEnabled {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode("mp")
+ } else {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
+ }
+ }
+ if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
+ if req.WeChatConnectRedirectURL == "" {
+ response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
+ return
+ }
+ if req.WeChatConnectFrontendRedirectURL == "" {
+ req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
+ }
+ if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
+ return
+ }
+ }
+ }
+
// Generic OIDC 参数验证
+ oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
@@ -478,10 +816,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
-
- if req.OIDCConnectProviderName == "" {
- req.OIDCConnectProviderName = "OIDC"
+ req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC"))
+ req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID))
+ req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL))
+ req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL))
+ req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL))
+ req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL))
+ req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL))
+ req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL))
+ req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile"))
+ req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL))
+ req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback"))
+ req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post")))
+ req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256"))
+ req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath))
+ req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath))
+ req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath))
+ if req.OIDCConnectUsePKCE != nil {
+ oidcUsePKCE = *req.OIDCConnectUsePKCE
}
+ if req.OIDCConnectValidateIDToken != nil {
+ oidcValidateIDToken = *req.OIDCConnectValidateIDToken
+ }
+ if req.OIDCConnectClockSkewSeconds == 0 {
+ req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds
+ if req.OIDCConnectClockSkewSeconds == 0 {
+ req.OIDCConnectClockSkewSeconds = 120
+ }
+ }
+
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
@@ -544,19 +907,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
- if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
- response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
- return
- }
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
- if req.OIDCConnectValidateIDToken {
- if req.OIDCConnectAllowedSigningAlgs == "" {
- response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
- return
- }
+ if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
+ response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
+ return
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
@@ -805,6 +1162,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: req.WeChatConnectEnabled,
+ WeChatConnectAppID: req.WeChatConnectAppID,
+ WeChatConnectAppSecret: req.WeChatConnectAppSecret,
+ WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
+ WeChatConnectMPAppID: req.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
+ WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
+ WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
+ WeChatConnectMode: req.WeChatConnectMode,
+ WeChatConnectScopes: req.WeChatConnectScopes,
+ WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
@@ -819,8 +1192,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
- OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
+ OIDCConnectUsePKCE: oidcUsePKCE,
+ OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
@@ -843,6 +1216,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
+ AffiliateRebateRate: affiliateRebateRate,
+ AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: affiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
@@ -897,6 +1275,42 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
+ EnableAnthropicCacheTTL1hInjection: func() bool {
+ if req.EnableAnthropicCacheTTL1hInjection != nil {
+ return *req.EnableAnthropicCacheTTL1hInjection
+ }
+ return previousSettings.EnableAnthropicCacheTTL1hInjection
+ }(),
+ PaymentVisibleMethodAlipaySource: func() string {
+ if req.PaymentVisibleMethodAlipaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
+ }
+ return previousSettings.PaymentVisibleMethodAlipaySource
+ }(),
+ PaymentVisibleMethodWxpaySource: func() string {
+ if req.PaymentVisibleMethodWxpaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource)
+ }
+ return previousSettings.PaymentVisibleMethodWxpaySource
+ }(),
+ PaymentVisibleMethodAlipayEnabled: func() bool {
+ if req.PaymentVisibleMethodAlipayEnabled != nil {
+ return *req.PaymentVisibleMethodAlipayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodAlipayEnabled
+ }(),
+ PaymentVisibleMethodWxpayEnabled: func() bool {
+ if req.PaymentVisibleMethodWxpayEnabled != nil {
+ return *req.PaymentVisibleMethodWxpayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodWxpayEnabled
+ }(),
+ OpenAIAdvancedSchedulerEnabled: func() bool {
+ if req.OpenAIAdvancedSchedulerEnabled != nil {
+ return *req.OpenAIAdvancedSchedulerEnabled
+ }
+ return previousSettings.OpenAIAdvancedSchedulerEnabled
+ }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
@@ -927,13 +1341,76 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AccountQuotaNotifyEmails
}(),
+ ChannelMonitorEnabled: func() bool {
+ if req.ChannelMonitorEnabled != nil {
+ return *req.ChannelMonitorEnabled
+ }
+ return previousSettings.ChannelMonitorEnabled
+ }(),
+ ChannelMonitorDefaultIntervalSeconds: func() int {
+ if req.ChannelMonitorDefaultIntervalSeconds != nil {
+ return *req.ChannelMonitorDefaultIntervalSeconds
+ }
+ return previousSettings.ChannelMonitorDefaultIntervalSeconds
+ }(),
+ AvailableChannelsEnabled: func() bool {
+ if req.AvailableChannelsEnabled != nil {
+ return *req.AvailableChannelsEnabled
+ }
+ return previousSettings.AvailableChannelsEnabled
+ }(),
+ AffiliateEnabled: func() bool {
+ if req.AffiliateEnabled != nil {
+ return *req.AffiliateEnabled
+ }
+ return previousSettings.AffiliateEnabled
+ }(),
}
- if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
+ authSourceDefaults := &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
+ },
+ LinuxDo: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
+ },
+ OIDC: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
+ },
+ WeChat: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
+ },
+ ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
+ }
+ if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
+ // Update OpenAI fast policy (stored under dedicated key, only when provided).
+ if req.OpenAIFastPolicySettings != nil {
+ if err := h.settingService.SetOpenAIFastPolicySettings(c.Request.Context(), openaiFastPolicySettingsFromDTO(req.OpenAIFastPolicySettings)); err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ }
+
// Update payment configuration (integrated into system settings).
// Skip if no payment fields were provided (prevents accidental wipe).
if h.paymentConfigService != nil && hasPaymentFields(req) {
@@ -969,7 +1446,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
- h.auditSettingsUpdate(c, previousSettings, settings, req)
+ h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
@@ -977,6 +1454,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
@@ -994,113 +1476,153 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{}
}
- response.Success(c, dto.SystemSettings{
- RegistrationEnabled: updatedSettings.RegistrationEnabled,
- EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
- PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
- FrontendURL: updatedSettings.FrontendURL,
- InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
- TotpEnabled: updatedSettings.TotpEnabled,
- TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
- SMTPHost: updatedSettings.SMTPHost,
- SMTPPort: updatedSettings.SMTPPort,
- SMTPUsername: updatedSettings.SMTPUsername,
- SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
- SMTPFrom: updatedSettings.SMTPFrom,
- SMTPFromName: updatedSettings.SMTPFromName,
- SMTPUseTLS: updatedSettings.SMTPUseTLS,
- TurnstileEnabled: updatedSettings.TurnstileEnabled,
- TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
- TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
- LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
- LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
- LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
- OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
- OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
- OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
- OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
- OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
- OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
- OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
- OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
- OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
- OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
- OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
- OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
- OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
- OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
- OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
- OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
- OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
- OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
- OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
- OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
- OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
- SiteName: updatedSettings.SiteName,
- SiteLogo: updatedSettings.SiteLogo,
- SiteSubtitle: updatedSettings.SiteSubtitle,
- APIBaseURL: updatedSettings.APIBaseURL,
- ContactInfo: updatedSettings.ContactInfo,
- DocURL: updatedSettings.DocURL,
- HomeContent: updatedSettings.HomeContent,
- HideCcsImportButton: updatedSettings.HideCcsImportButton,
- PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
- PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
- TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
- TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
- CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
- CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
- DefaultConcurrency: updatedSettings.DefaultConcurrency,
- DefaultBalance: updatedSettings.DefaultBalance,
- DefaultSubscriptions: updatedDefaultSubscriptions,
- EnableModelFallback: updatedSettings.EnableModelFallback,
- FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
- FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
- FallbackModelGemini: updatedSettings.FallbackModelGemini,
- FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
- EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
- IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
- OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
- OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
- OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
- OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
- BackendModeEnabled: updatedSettings.BackendModeEnabled,
- EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
- EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
- EnableCCHSigning: updatedSettings.EnableCCHSigning,
- BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
- BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
- BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
- AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
- AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
- PaymentEnabled: updatedPaymentCfg.Enabled,
- PaymentMinAmount: updatedPaymentCfg.MinAmount,
- PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
- PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
- PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
- PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
- PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
- PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
- PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
- PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
- PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
- PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
- PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
- PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
- PaymentHelpText: updatedPaymentCfg.HelpText,
- PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
- PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
- PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
- PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
- PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
- })
+ payload := dto.SystemSettings{
+ RegistrationEnabled: updatedSettings.RegistrationEnabled,
+ EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
+ PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
+ FrontendURL: updatedSettings.FrontendURL,
+ InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
+ TotpEnabled: updatedSettings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
+ SMTPHost: updatedSettings.SMTPHost,
+ SMTPPort: updatedSettings.SMTPPort,
+ SMTPUsername: updatedSettings.SMTPUsername,
+ SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
+ SMTPFrom: updatedSettings.SMTPFrom,
+ SMTPFromName: updatedSettings.SMTPFromName,
+ SMTPUseTLS: updatedSettings.SMTPUseTLS,
+ TurnstileEnabled: updatedSettings.TurnstileEnabled,
+ TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
+ LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
+ LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
+ WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
+ WeChatConnectOpenAppID: updatedSettings.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecretConfigured: updatedSettings.WeChatConnectOpenAppSecretConfigured,
+ WeChatConnectMPAppID: updatedSettings.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecretConfigured: updatedSettings.WeChatConnectMPAppSecretConfigured,
+ WeChatConnectMobileAppID: updatedSettings.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecretConfigured: updatedSettings.WeChatConnectMobileAppSecretConfigured,
+ WeChatConnectOpenEnabled: updatedSettings.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: updatedSettings.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: updatedSettings.WeChatConnectMobileEnabled,
+ WeChatConnectMode: updatedSettings.WeChatConnectMode,
+ WeChatConnectScopes: updatedSettings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
+ OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
+ OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
+ OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
+ OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
+ OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
+ OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
+ OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
+ SiteName: updatedSettings.SiteName,
+ SiteLogo: updatedSettings.SiteLogo,
+ SiteSubtitle: updatedSettings.SiteSubtitle,
+ APIBaseURL: updatedSettings.APIBaseURL,
+ ContactInfo: updatedSettings.ContactInfo,
+ DocURL: updatedSettings.DocURL,
+ HomeContent: updatedSettings.HomeContent,
+ HideCcsImportButton: updatedSettings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
+ TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
+ CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
+ CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
+ DefaultConcurrency: updatedSettings.DefaultConcurrency,
+ DefaultBalance: updatedSettings.DefaultBalance,
+ AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
+ AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
+ DefaultSubscriptions: updatedDefaultSubscriptions,
+ EnableModelFallback: updatedSettings.EnableModelFallback,
+ FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
+ FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
+ FallbackModelGemini: updatedSettings.FallbackModelGemini,
+ FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
+ EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
+ IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
+ OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
+ OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
+ OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
+ OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: updatedSettings.BackendModeEnabled,
+ EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
+ EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
+ EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ EnableAnthropicCacheTTL1hInjection: updatedSettings.EnableAnthropicCacheTTL1hInjection,
+ PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled,
+ BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
+ PaymentEnabled: updatedPaymentCfg.Enabled,
+ PaymentMinAmount: updatedPaymentCfg.MinAmount,
+ PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
+ PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
+ PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
+ PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
+ PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
+ PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
+ PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
+ PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
+ PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
+ PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
+ PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
+ PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
+ PaymentHelpText: updatedPaymentCfg.HelpText,
+ PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
+ PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
+ PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
+ PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
+ PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
+
+ ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
+
+ AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: updatedSettings.AffiliateEnabled,
+ }
+ if fastPolicy, err := h.settingService.GetOpenAIFastPolicySettings(c.Request.Context()); err != nil {
+ slog.Error("openai_fast_policy_settings_get_failed", "error", err)
+ } else if fastPolicy != nil {
+ payload.OpenAIFastPolicySettings = openaiFastPolicySettingsToDTO(fastPolicy)
+ }
+ response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
@@ -1117,12 +1639,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
}
-func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
+func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
- changed := diffSettings(before, after, req)
+ changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req)
if len(changed) == 0 {
return
}
@@ -1137,7 +1659,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
)
}
-func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
+func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
@@ -1205,6 +1727,54 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
+ if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
+ changed = append(changed, "wechat_connect_enabled")
+ }
+ if before.WeChatConnectAppID != after.WeChatConnectAppID {
+ changed = append(changed, "wechat_connect_app_id")
+ }
+ if req.WeChatConnectAppSecret != "" {
+ changed = append(changed, "wechat_connect_app_secret")
+ }
+ if before.WeChatConnectOpenAppID != after.WeChatConnectOpenAppID {
+ changed = append(changed, "wechat_connect_open_app_id")
+ }
+ if req.WeChatConnectOpenAppSecret != "" {
+ changed = append(changed, "wechat_connect_open_app_secret")
+ }
+ if before.WeChatConnectMPAppID != after.WeChatConnectMPAppID {
+ changed = append(changed, "wechat_connect_mp_app_id")
+ }
+ if req.WeChatConnectMPAppSecret != "" {
+ changed = append(changed, "wechat_connect_mp_app_secret")
+ }
+ if before.WeChatConnectMobileAppID != after.WeChatConnectMobileAppID {
+ changed = append(changed, "wechat_connect_mobile_app_id")
+ }
+ if req.WeChatConnectMobileAppSecret != "" {
+ changed = append(changed, "wechat_connect_mobile_app_secret")
+ }
+ if before.WeChatConnectOpenEnabled != after.WeChatConnectOpenEnabled {
+ changed = append(changed, "wechat_connect_open_enabled")
+ }
+ if before.WeChatConnectMPEnabled != after.WeChatConnectMPEnabled {
+ changed = append(changed, "wechat_connect_mp_enabled")
+ }
+ if before.WeChatConnectMobileEnabled != after.WeChatConnectMobileEnabled {
+ changed = append(changed, "wechat_connect_mobile_enabled")
+ }
+ if before.WeChatConnectMode != after.WeChatConnectMode {
+ changed = append(changed, "wechat_connect_mode")
+ }
+ if before.WeChatConnectScopes != after.WeChatConnectScopes {
+ changed = append(changed, "wechat_connect_scopes")
+ }
+ if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL {
+ changed = append(changed, "wechat_connect_redirect_url")
+ }
+ if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL {
+ changed = append(changed, "wechat_connect_frontend_redirect_url")
+ }
if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
changed = append(changed, "oidc_connect_enabled")
}
@@ -1301,6 +1871,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
+ if before.AffiliateRebateRate != after.AffiliateRebateRate {
+ changed = append(changed, "affiliate_rebate_rate")
+ }
+ if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
+ changed = append(changed, "affiliate_rebate_freeze_hours")
+ }
+ if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
+ changed = append(changed, "affiliate_rebate_duration_days")
+ }
+ if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
+ changed = append(changed, "affiliate_rebate_per_invitee_cap")
+ }
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
@@ -1376,6 +1958,24 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
+ if before.EnableAnthropicCacheTTL1hInjection != after.EnableAnthropicCacheTTL1hInjection {
+ changed = append(changed, "enable_anthropic_cache_ttl_1h_injection")
+ }
+ if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
+ changed = append(changed, "payment_visible_method_alipay_source")
+ }
+ if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource {
+ changed = append(changed, "payment_visible_method_wxpay_source")
+ }
+ if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled {
+ changed = append(changed, "payment_visible_method_alipay_enabled")
+ }
+ if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled {
+ changed = append(changed, "payment_visible_method_wxpay_enabled")
+ }
+ if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
+ changed = append(changed, "openai_advanced_scheduler_enabled")
+ }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
@@ -1392,6 +1992,62 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
+ if before.ChannelMonitorEnabled != after.ChannelMonitorEnabled {
+ changed = append(changed, "channel_monitor_enabled")
+ }
+ if before.ChannelMonitorDefaultIntervalSeconds != after.ChannelMonitorDefaultIntervalSeconds {
+ changed = append(changed, "channel_monitor_default_interval_seconds")
+ }
+ if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
+ changed = append(changed, "available_channels_enabled")
+ }
+ if before.AffiliateEnabled != after.AffiliateEnabled {
+ changed = append(changed, "affiliate_enabled")
+ }
+ changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
+ return changed
+}
+
+func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string {
+ if before == nil {
+ before = &service.AuthSourceDefaultSettings{}
+ }
+ if after == nil {
+ after = &service.AuthSourceDefaultSettings{}
+ }
+
+ type providerDefaultGrantField struct {
+ name string
+ before service.ProviderDefaultGrantSettings
+ after service.ProviderDefaultGrantSettings
+ }
+
+ fields := []providerDefaultGrantField{
+ {name: "email", before: before.Email, after: after.Email},
+ {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
+ {name: "oidc", before: before.OIDC, after: after.OIDC},
+ {name: "wechat", before: before.WeChat, after: after.WeChat},
+ }
+ for _, field := range fields {
+ if field.before.Balance != field.after.Balance {
+ changed = append(changed, "auth_source_default_"+field.name+"_balance")
+ }
+ if field.before.Concurrency != field.after.Concurrency {
+ changed = append(changed, "auth_source_default_"+field.name+"_concurrency")
+ }
+ if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) {
+ changed = append(changed, "auth_source_default_"+field.name+"_subscriptions")
+ }
+ if field.before.GrantOnSignup != field.after.GrantOnSignup {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup")
+ }
+ if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
+ }
+ }
+ if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
+ changed = append(changed, "force_email_on_third_party_signup")
+ }
return changed
}
@@ -1412,6 +2068,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
+func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
+ if input == nil {
+ return nil
+ }
+ normalized := normalizeDefaultSubscriptions(*input)
+ return &normalized
+}
+
+func float64ValueOrDefault(value *float64, fallback float64) float64 {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func intValueOrDefault(value *int, fallback int) int {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func boolValueOrDefault(value *bool, fallback bool) bool {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
+ if input == nil {
+ return fallback
+ }
+ result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
+ for _, item := range *input {
+ result = append(result, service.DefaultSubscriptionSetting{
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ })
+ }
+ return result
+}
+
+func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
+ data := make(map[string]any)
+ raw, err := json.Marshal(settings)
+ if err == nil {
+ _ = json.Unmarshal(raw, &data)
+ }
+ if authSourceDefaults == nil {
+ authSourceDefaults = &service.AuthSourceDefaultSettings{}
+ }
+
+ data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
+ data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
+ data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
+ data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
+ data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
+ data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
+ data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
+ data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
+ data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
+ data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
+ data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
+ data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
+ data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
+ data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
+ data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
+ data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
+ data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
+ data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
+ data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
+ data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
+ data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
+
+ return data
+}
+
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
new file mode 100644
index 00000000..085fd2ca
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -0,0 +1,508 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerRepoStub struct {
+ values map[string]string
+ lastUpdates map[string]string
+}
+
+func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ if s.values != nil {
+ if value, ok := s.values[key]; ok {
+ return value, nil
+ }
+ }
+ return "", nil
+}
+
+func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.lastUpdates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.lastUpdates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+type failingAuthSourceSettingsRepoStub struct {
+ values map[string]string
+ err error
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok {
+ return s.err
+ }
+ for key, value := range settings {
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
+
+ handler.GetSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 9.5, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+
+ subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
+ require.True(t, ok)
+ require.Len(t, subscriptions, 1)
+}
+
+func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
+ require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 12.75, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+}
+
+func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "easypay",
+ "payment_visible_method_wxpay_source": "wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"])
+ require.Equal(t, true, data["payment_visible_method_alipay_enabled"])
+ require.Equal(t, false, data["payment_visible_method_wxpay_enabled"])
+ require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
+}
+
+func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodAlipaySource: "",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ service.SettingPaymentVisibleMethodWxpaySource: "",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": false,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+}
+
+func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyOIDCConnectEnabled: "true",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "oidc-client",
+ service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
+ service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
+ service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
+ service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
+ service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
+ service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
+ service.SettingKeyOIDCConnectValidateIDToken: "true",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "oidc_connect_enabled": true,
+ "oidc_connect_use_pkce": false,
+ "oidc_connect_validate_id_token": false,
+ "oidc_connect_allowed_signing_algs": "",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, data["oidc_connect_use_pkce"])
+ require.Equal(t, false, data["oidc_connect_validate_id_token"])
+}
+
+func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyOIDCConnectEnabled: "true",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "oidc-client",
+ service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
+ service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
+ service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
+ service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
+ service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
+ service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ service.SettingKeyOIDCConnectRequireEmailVerified: "false",
+ service.SettingKeyOIDCConnectUserInfoEmailPath: "",
+ service.SettingKeyOIDCConnectUserInfoIDPath: "",
+ service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{
+ Default: config.DefaultConfig{UserConcurrency: 5},
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ Scopes: "openid email profile",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ })
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "oidc_connect_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
+}
+
+func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "bogus",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
+}
+
+func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &failingAuthSourceSettingsRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ },
+ err: errors.New("write auth source defaults failed"),
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled])
+ require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+}
+
+func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) {
+ changed := diffSettings(
+ &service.SystemSettings{},
+ &service.SystemSettings{},
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 0,
+ Concurrency: 5,
+ Subscriptions: nil,
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: false,
+ },
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 12.5,
+ Concurrency: 7,
+ Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ },
+ UpdateSettingsRequest{},
+ )
+
+ require.Contains(t, changed, "auth_source_default_email_balance")
+ require.Contains(t, changed, "auth_source_default_email_concurrency")
+ require.Contains(t, changed, "auth_source_default_email_subscriptions")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_signup")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind")
+ require.Contains(t, changed, "force_email_on_third_party_signup")
+}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 1453bd07..3d80107f 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -40,6 +40,7 @@ type CreateUserRequest struct {
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
+ RPMLimit int `json:"rpm_limit"`
AllowedGroups []int64 `json:"allowed_groups"`
}
@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
+ RPMLimit *int `json:"rpm_limit"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
@@ -66,6 +68,22 @@ type UpdateBalanceRequest struct {
Notes string `json:"notes"`
}
+type BindUserAuthIdentityRequest struct {
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ Issuer *string `json:"issuer"`
+ Metadata map[string]any `json:"metadata"`
+ Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
+}
+
+type BindUserAuthIdentityChannelRequest struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+}
+
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
@@ -172,6 +190,45 @@ func (h *UserHandler) GetByID(c *gin.Context) {
response.Success(c, dto.UserFromServiceAdmin(user))
}
+// BindAuthIdentity manually binds a canonical auth identity to a user.
+// POST /api/v1/admin/users/:id/auth-identities
+func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ var req BindUserAuthIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ input := service.AdminBindAuthIdentityInput{
+ ProviderType: req.ProviderType,
+ ProviderKey: req.ProviderKey,
+ ProviderSubject: req.ProviderSubject,
+ Issuer: req.Issuer,
+ Metadata: req.Metadata,
+ }
+ if req.Channel != nil {
+ input.Channel = &service.AdminBindAuthIdentityChannelInput{
+ Channel: req.Channel.Channel,
+ ChannelAppID: req.Channel.ChannelAppID,
+ ChannelSubject: req.Channel.ChannelSubject,
+ Metadata: req.Channel.Metadata,
+ }
+ }
+
+ result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
+
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
@@ -188,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
+ RPMLimit: req.RPMLimit,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
@@ -221,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
+ RPMLimit: req.RPMLimit,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
@@ -400,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys": result.MigratedKeys,
})
}
+
+// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
+// GET /api/v1/admin/users/:id/rpm-status
+func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, status)
+}
diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go
new file mode 100644
index 00000000..bfba2408
--- /dev/null
+++ b/backend/internal/handler/admin/user_handler_activity_test.go
@@ -0,0 +1,114 @@
+//go:build unit
+
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 7,
+ Email: "activity@example.com",
+ Username: "activity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity",
+ nil,
+ )
+
+ handler.List(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy)
+ require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder)
+ require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Items []struct {
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"items"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Items, 1)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second)
+}
+
+func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 8,
+ Email: "detail@example.com",
+ Username: "detail-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Params = gin.Params{{Key: "id", Value: "8"}}
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil)
+
+ handler.GetByID(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go
new file mode 100644
index 00000000..cb3e4ba5
--- /dev/null
+++ b/backend/internal/handler/auth_current_user_test.go
@@ -0,0 +1,86 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 31,
+ Email: "me@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-31",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+
+ handler := &AuthHandler{
+ userService: service.NewUserService(repo, nil, nil, nil),
+ }
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
+
+ handler.GetCurrentUser(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
+}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index f4ddf890..1f9a66ff 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -1,11 +1,13 @@
package handler
import (
+ "context"
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -46,6 +48,7 @@ type RegisterRequest struct {
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
+ AffCode string `json:"aff_code"` // 邀请返利码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -76,9 +79,24 @@ type AuthResponse struct {
User *dto.User `json:"user"`
}
+func ensureLoginUserActive(user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if !user.IsActive() {
+ return service.ErrUserNotActive
+ }
+ return nil
+}
+
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
+ if err := ensureLoginUserActive(user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
@@ -104,6 +122,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
})
}
+func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
+ if h == nil || !h.isBackendModeEnabled(ctx) {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ settings, err := h.settingSvc.GetPublicSettings(ctx)
+ if err == nil && settings != nil {
+ return settings.BackendModeEnabled
+ }
+ return h.settingSvc.IsBackendModeEnabled(ctx)
+}
+
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
@@ -119,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
- _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
+ _, user, err := h.authService.RegisterWithVerification(
+ c.Request.Context(),
+ req.Email,
+ req.Password,
+ req.VerifyCode,
+ req.PromoCode,
+ req.InvitationCode,
+ req.AffCode,
+ )
if err != nil {
response.ErrorFrom(c, err)
return
@@ -177,6 +231,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
@@ -194,11 +253,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
- // Backend mode: only admin can login
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
- return
- }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
h.respondWithTokenPair(c, user)
}
@@ -262,16 +317,80 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
-
- // Backend mode: only admin can login (check BEFORE deleting session)
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
+ if err := ensureLoginUserActive(user); err != nil {
+ response.ErrorFrom(c, err)
return
}
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if session.PendingOAuthBind != nil {
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ pendingSession, err := pendingSvc.GetBrowserSession(
+ c.Request.Context(),
+ session.PendingOAuthBind.PendingSessionToken,
+ session.PendingOAuthBind.BrowserSessionKey,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{})
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(
+ c.Request.Context(),
+ h.entClient(),
+ h.authService,
+ h.userService,
+ pendingSession,
+ decision,
+ &user.ID,
+ true,
+ true,
+ ); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(
+ c.Request.Context(),
+ pendingSession.SessionToken,
+ pendingSession.BrowserSessionKey,
+ ); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+
+ user, err = h.userService.GetByID(c.Request.Context(), session.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
+ if session.PendingOAuthBind == nil {
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ }
+
h.respondWithTokenPair(c, user)
}
@@ -290,8 +409,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
+ identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
type UserResponse struct {
- *dto.User
+ userProfileResponse
RunMode string `json:"run_mode"`
}
@@ -300,7 +425,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
- response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
+ response.Success(c, UserResponse{
+ userProfileResponse: userProfileResponseFromService(user, identities),
+ RunMode: runMode,
+ })
}
// ValidatePromoCodeRequest 验证优惠码请求
@@ -578,6 +706,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
+ h.consumePendingOAuthSessionOnLogout(c)
+ clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
@@ -598,7 +728,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
return
}
- if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
+ if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 0c7c2da7..7df4abfd 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -2,6 +2,8 @@ package handler
import (
"context"
+ "crypto/hmac"
+ "crypto/sha256"
"encoding/base64"
"errors"
"fmt"
@@ -13,10 +15,13 @@ import (
"time"
"unicode/utf8"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,17 +30,24 @@ import (
)
const (
- linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
- linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
- linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
- linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
- linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
- linuxDoOAuthDefaultRedirectTo = "/dashboard"
- linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
+ linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
+ oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth"
+ linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
+ linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
+ linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
+ linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent"
+ linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user"
+ oauthBindAccessTokenCookieName = "oauth_bind_access_token"
+ linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
+ linuxDoOAuthDefaultRedirectTo = "/dashboard"
+ linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
+
+ oauthIntentLogin = "login"
+ oauthIntentBindCurrentUser = "bind_current_user"
)
type linuxDoTokenResponse struct {
@@ -87,9 +99,29 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ intent := normalizeOAuthIntent(c.Query("intent"))
+ setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
+ }
codeChallenge := ""
if cfg.UsePKCE {
@@ -148,6 +180,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
@@ -161,6 +195,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+ intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
+ intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
@@ -198,52 +239,205 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
+ email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
+ compatEmail := strings.TrimSpace(email)
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
-
- // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
- if err != nil {
- if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
- return
- }
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ identityKey := service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ }
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
+ if intent == oauthIntentBindCurrentUser {
+ targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
return
}
- // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
- redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityKey,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityKey,
+ TargetUserID: &existingIdentityUser.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createLinuxDoOAuthChoicePendingSession(
+ c,
+ identityKey,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ h.isForceEmailOnThirdPartySignup(c.Request.Context()),
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ switch len(userEntity) {
+ case 0:
+ return nil, nil
+ case 1:
+ return userEntity[0], nil
+ default:
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+}
+
+func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+ if forceEmailOnSignup && compatEmailUser == nil {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ var targetUserID *int64
+ if compatEmailUser != nil && compatEmailUser.ID > 0 {
+ targetUserID = &compatEmailUser.ID
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ TargetUserID: targetUserID,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
}
type completeLinuxDoOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -256,17 +450,87 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
-
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+ if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -303,7 +567,7 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
@@ -353,11 +617,11 @@ func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
-) (email string, username string, subject string, err error) {
+) (email string, username string, subject string, displayName string, avatarURL string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
- return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
+ return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
@@ -366,16 +630,16 @@ func linuxDoFetchUserInfo(
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
- return "", "", "", fmt.Errorf("request userinfo: %w", err)
+ return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
- return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
+ return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
-func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
+func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
@@ -400,12 +664,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
getGJSON(body, "user.id"),
)
+ displayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "user.name"),
+ getGJSON(body, "user.username"),
+ username,
+ )
+ avatarURL = firstNonEmpty(
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "picture"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
+
subject = strings.TrimSpace(subject)
if subject == "" {
- return "", "", "", errors.New("userinfo missing id field")
+ return "", "", "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
- return "", "", "", errors.New("userinfo returned invalid id field")
+ return "", "", "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
@@ -418,8 +699,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
if username == "" {
username = "linuxdo_" + subject
}
+ displayName = strings.TrimSpace(displayName)
+ if displayName == "" {
+ displayName = username
+ }
+ avatarURL = strings.TrimSpace(avatarURL)
- return email, username, subject, nil
+ return email, username, subject, displayName, avatarURL, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
@@ -436,7 +722,7 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
@@ -670,6 +956,30 @@ func clearCookie(c *gin.Context, name string, secure bool) {
})
}
+func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthBindAccessTokenCookieName,
+ Value: "",
+ Path: oauthBindAccessTokenCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthBindAccessTokenCookieName,
+ Value: url.QueryEscape(strings.TrimSpace(token)),
+ Path: oauthBindAccessTokenCookiePath,
+ MaxAge: linuxDoOAuthCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
@@ -728,3 +1038,127 @@ func linuxDoSyntheticEmail(subject string) string {
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
+
+func normalizeOAuthIntent(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "", oauthIntentLogin:
+ return oauthIntentLogin
+ case "bind", oauthIntentBindCurrentUser:
+ return oauthIntentBindCurrentUser
+ default:
+ return oauthIntentLogin
+ }
+}
+
+func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) {
+ userID, err := h.resolveOAuthBindTargetUserID(c)
+ if err != nil || userID == nil || *userID <= 0 {
+ return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")
+ }
+ return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
+}
+
+func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
+ const bearerPrefix = "Bearer "
+
+ authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
+ if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
+ response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
+ return
+ }
+
+ token := strings.TrimSpace(authHeader[len(bearerPrefix):])
+ if token == "" {
+ response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
+ return
+ }
+
+ setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
+ c.Status(http.StatusNoContent)
+ c.Writer.WriteHeaderNow()
+}
+
+func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
+ if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
+ return &subject.UserID, nil
+ }
+ if h == nil || h.authService == nil || h.userService == nil {
+ return nil, service.ErrInvalidToken
+ }
+
+ ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName)
+ clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c))
+ if err != nil {
+ return nil, err
+ }
+
+ tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value))
+ if err != nil {
+ return nil, err
+ }
+ if tokenString == "" {
+ return nil, service.ErrInvalidToken
+ }
+
+ claims, err := h.authService.ValidateToken(tokenString)
+ if err != nil {
+ return nil, err
+ }
+ user, err := h.userService.GetByID(c.Request.Context(), claims.UserID)
+ if err != nil {
+ return nil, err
+ }
+ if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion {
+ return nil, service.ErrInvalidToken
+ }
+ return &user.ID, nil
+}
+
+func (h *AuthHandler) readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) {
+ value, err := readCookieDecoded(c, cookieName)
+ if err != nil {
+ return 0, err
+ }
+ return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret())
+}
+
+func (h *AuthHandler) oauthBindCookieSecret() string {
+ if h == nil || h.cfg == nil {
+ return ""
+ }
+ return strings.TrimSpace(h.cfg.JWT.Secret)
+}
+
+func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) {
+ secret = strings.TrimSpace(secret)
+ if userID <= 0 || secret == "" {
+ return "", errors.New("invalid oauth bind cookie input")
+ }
+ payload := strconv.FormatInt(userID, 10)
+ mac := hmac.New(sha256.New, []byte(secret))
+ _, _ = mac.Write([]byte(payload))
+ signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ return payload + "." + signature, nil
+}
+
+func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) {
+ secret = strings.TrimSpace(secret)
+ if secret == "" {
+ return 0, errors.New("missing oauth bind cookie secret")
+ }
+ payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".")
+ if !ok || payload == "" || signature == "" {
+ return 0, errors.New("invalid oauth bind cookie")
+ }
+ mac := hmac.New(sha256.New, []byte(secret))
+ _, _ = mac.Write([]byte(payload))
+ expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
+ return 0, errors.New("invalid oauth bind cookie signature")
+ }
+ userID, err := strconv.ParseInt(payload, 10, 64)
+ if err != nil || userID <= 0 {
+ return 0, errors.New("invalid oauth bind cookie user")
+ }
+ return userID, nil
+}
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index ff169c52..8b01ab41 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -1,10 +1,24 @@
package handler
import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
"strings"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -41,11 +55,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "Alice", displayName)
+ require.Equal(t, "https://cdn.example/avatar.png", avatarURL)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
@@ -53,11 +69,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "linuxdo_123", displayName)
+ require.Equal(t, "", avatarURL)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
@@ -65,11 +83,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
+ _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
- _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
+ _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
@@ -106,3 +124,906 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
+
+func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
+ handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ c.Request = req
+ c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42})
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.Contains(t, location, "connect.linux.do/oauth/authorize")
+ require.Contains(t, location, "client_id=linuxdo-client")
+ require.Contains(t, location, "code_challenge=")
+
+ cookies := recorder.Result().Cookies()
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName))
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie))
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie))
+ require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))
+
+ intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName)
+ require.NotNil(t, intentCookie)
+ require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))
+
+ bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, int64(42), userID)
+}
+
+func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
+ handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
+ require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
+}
+
+func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, r.ParseForm())
+ require.Empty(t, r.PostForm.Get("code_verifier"))
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+ require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+}
+
+func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("bind-cookie@example.com").
+ SetUsername("bind-cookie-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(context.Background())
+ require.NoError(t, err)
+
+ token, err := handler.authService.GenerateToken(&service.User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ PasswordHash: user.PasswordHash,
+ Role: user.Role,
+ Status: user.Status,
+ })
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath})
+ c.Request = req
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+
+ bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, user.ID, userID)
+
+ accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
+ require.NotNil(t, accessTokenCookie)
+ require.Equal(t, -1, accessTokenCookie.MaxAge)
+}
+
+func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
+ req.Header.Set("Authorization", "Bearer access-token-value")
+ c.Request = req
+
+ handler.PrepareOAuthBindAccessTokenCookie(c)
+
+ require.Equal(t, http.StatusNoContent, recorder.Code)
+ accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
+ require.NotNil(t, accessTokenCookie)
+ require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
+ require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
+ require.True(t, accessTokenCookie.HttpOnly)
+ require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
+}
+
+func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(linuxDoSyntheticEmail("321")).
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("321").
+ SetMetadata(map[string]any{"username": "legacy-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail)
+ require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+ require.Nil(t, completion["error"])
+}
+
+func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(linuxDoSyntheticEmail("654")).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("654").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(" Legacy@Example.com ").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
+ require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
+ require.Equal(t, true, completion["existing_account_bindable"])
+ require.Equal(t, "compat_email_match", completion["choice_reason"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Nil(t, session.TargetUserID)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser))
+ req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Empty(t, completion["access_token"])
+ require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, userCount)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-1").
+ SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Display",
+ "suggested_avatar_url": "https://cdn.example/linuxdo.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "LinuxDo Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("linuxdo-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-choice-subject-1").
+ SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-no-adoption").
+ SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Legacy",
+ "suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "linuxdo_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingOwner.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-conflict-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-conflict-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-conflict-subject").
+ SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-conflict-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
+
+ userCount, err := client.User.Query().
+ Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
+ t.Helper()
+ handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
+ return handler
+}
+
+func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+ handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
+ handler.settingSvc = nil
+ handler.cfg = &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ LinuxDo: oauthCfg,
+ }
+ return handler, client
+}
diff --git a/backend/internal/handler/auth_oauth_logout_test.go b/backend/internal/handler/auth_oauth_logout_test.go
new file mode 100644
index 00000000..0d4f94b1
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_logout_test.go
@@ -0,0 +1,68 @@
+package handler
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("logout-pending-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("logout-subject-123").
+ SetBrowserSessionKey("logout-browser-session-key").
+ SetResolvedEmail("logout@example.com").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
+ req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
+ req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
+ req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
+ req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
+ ginCtx.Request = req
+
+ handler.Logout(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ cookies := recorder.Result().Cookies()
+ for _, name := range []string{
+ oauthPendingSessionCookieName,
+ oauthPendingBrowserCookieName,
+ oauthBindAccessTokenCookieName,
+ linuxDoOAuthStateCookieName,
+ oidcOAuthStateCookieName,
+ wechatOAuthStateCookieName,
+ wechatPaymentOAuthStateName,
+ } {
+ cookie := findCookie(cookies, name)
+ require.NotNil(t, cookie, name)
+ require.Equal(t, -1, cookie.MaxAge, name)
+ require.True(t, cookie.HttpOnly, name)
+ }
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
new file mode 100644
index 00000000..490afd0f
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -0,0 +1,1946 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ entsql "entgo.io/ent/dialect/sql"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
+ oauthPendingBrowserCookieName = "oauth_pending_browser_session"
+ oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
+ oauthPendingSessionCookieName = "oauth_pending_session"
+ oauthPendingCookieMaxAgeSec = 10 * 60
+ oauthPendingChoiceStep = "choose_account_action_required"
+
+ oauthCompletionResponseKey = "completion_response"
+)
+
+var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
+
+type oauthPendingSessionPayload struct {
+ Intent string
+ Identity service.PendingAuthIdentityKey
+ TargetUserID *int64
+ ResolvedEmail string
+ RedirectTo string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ CompletionResponse map[string]any
+}
+
+type oauthAdoptionDecisionRequest struct {
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type bindPendingOAuthLoginRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type createPendingOAuthAccountRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code,omitempty"`
+ Password string `json:"password" binding:"required,min=6"`
+ InvitationCode string `json:"invitation_code,omitempty"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+type sendPendingOAuthVerifyCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ TurnstileToken string `json:"turnstile_token,omitempty"`
+ PendingAuthToken string `json:"pending_auth_token,omitempty"`
+ PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
+}
+
+func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
+ if h == nil || h.authService == nil || h.authService.EntClient() == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil
+}
+
+func generateOAuthPendingBrowserSession() (string, error) {
+ return oauth.GenerateState()
+}
+
+func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: encodeCookieValue(sessionKey),
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingBrowserCookieName,
+ Value: "",
+ Path: oauthPendingBrowserCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingBrowserCookieName)
+}
+
+func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: encodeCookieValue(sessionToken),
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: oauthPendingCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthPendingSessionCookieName,
+ Value: "",
+ Path: oauthPendingSessionCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func readOAuthPendingSessionCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingSessionCookieName)
+}
+
+func redirectToFrontendCallback(c *gin.Context, frontendCallback string) {
+ u, err := url.Parse(frontendCallback)
+ if err != nil {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ u.Fragment = ""
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.Redirect(http.StatusFound, u.String())
+}
+
+func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error {
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return err
+ }
+
+ session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
+ Intent: strings.TrimSpace(payload.Intent),
+ Identity: payload.Identity,
+ TargetUserID: payload.TargetUserID,
+ ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
+ RedirectTo: strings.TrimSpace(payload.RedirectTo),
+ BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
+ UpstreamIdentityClaims: payload.UpstreamIdentityClaims,
+ LocalFlowState: map[string]any{
+ oauthCompletionResponseKey: payload.CompletionResponse,
+ },
+ })
+ if err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
+ }
+
+ setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c))
+ return nil
+}
+
+func readCompletionResponse(session map[string]any) (map[string]any, bool) {
+ if len(session) == 0 {
+ return nil, false
+ }
+ value, ok := session[oauthCompletionResponseKey]
+ if !ok {
+ return nil, false
+ }
+ result, ok := value.(map[string]any)
+ if !ok {
+ return nil, false
+ }
+ return result, true
+}
+
+func clonePendingMap(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ merged := clonePendingMap(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := merged["redirect"]; !exists {
+ merged["redirect"] = session.RedirectTo
+ }
+ }
+ for key, value := range overrides {
+ if value == nil {
+ delete(merged, key)
+ continue
+ }
+ merged[key] = value
+ }
+ applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
+ return merged
+}
+
+func pendingSessionStringValue(values map[string]any, key string) string {
+ if len(values) == 0 {
+ return ""
+ }
+ raw, ok := values[key]
+ if !ok {
+ return ""
+ }
+ value, ok := raw.(string)
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(value)
+}
+
+func pendingSessionWantsInvitation(payload map[string]any) bool {
+ return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
+}
+
+func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
+ if session == nil {
+ return false
+ }
+ if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
+ return false
+ }
+ if session.TargetUserID == nil || *session.TargetUserID <= 0 {
+ return false
+ }
+ if pendingSessionWantsInvitation(payload) {
+ return false
+ }
+ return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
+}
+
+func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
+ if session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if strings.TrimSpace(session.Intent) != oauthIntentLogin {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ return nil
+}
+
+func buildLegacyCompleteRegistrationPendingResponse(
+ session *dbent.PendingAuthSession,
+ forceEmailOnSignup bool,
+ emailVerificationRequired bool,
+) map[string]any {
+ completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ }))
+
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ if _, exists := completionResponse["email"]; !exists {
+ completionResponse["email"] = email
+ }
+ if _, exists := completionResponse["resolved_email"]; !exists {
+ completionResponse["resolved_email"] = email
+ }
+ }
+ if _, exists := completionResponse["choice_reason"]; !exists {
+ switch {
+ case forceEmailOnSignup:
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ case emailVerificationRequired:
+ completionResponse["choice_reason"] = "email_verification_required"
+ default:
+ completionResponse["choice_reason"] = "third_party_signup"
+ }
+ }
+ return completionResponse
+}
+
+func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
+ c *gin.Context,
+ session *dbent.PendingAuthSession,
+) (*dbent.PendingAuthSession, bool, error) {
+ if session == nil {
+ return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
+ if step := pendingSessionStringValue(payload, "step"); step != "" {
+ return session, true, nil
+ }
+
+ emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
+ forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
+ if !emailVerificationRequired && !forceEmailOnSignup {
+ return session, false, nil
+ }
+
+ client := h.entClient()
+ if client == nil {
+ return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ updatedSession, err := updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ strings.TrimSpace(session.Intent),
+ strings.TrimSpace(session.ResolvedEmail),
+ nil,
+ buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
+ )
+ if err != nil {
+ return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+ }
+ return updatedSession, true, nil
+}
+
+func (r oauthAdoptionDecisionRequest) hasDecision() bool {
+ return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
+}
+
+func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
+ var req oauthAdoptionDecisionRequest
+ if c == nil || c.Request == nil || c.Request.Body == nil {
+ return req, nil
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ if errors.Is(err, io.EOF) {
+ return req, nil
+ }
+ return req, err
+ }
+ return req, nil
+}
+
+func cloneOAuthMetadata(values map[string]any) map[string]any {
+ if len(values) == 0 {
+ return map[string]any{}
+ }
+ cloned := make(map[string]any, len(values))
+ for key, value := range values {
+ cloned[key] = value
+ }
+ return cloned
+}
+
+func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
+ merged := cloneOAuthMetadata(base)
+ for key, value := range overlay {
+ merged[key] = value
+ }
+ return merged
+}
+
+func normalizeAdoptedOAuthDisplayName(value string) string {
+ value = strings.TrimSpace(value)
+ if len([]rune(value)) > 100 {
+ value = string([]rune(value)[:100])
+ }
+ return value
+}
+
+func (h *AuthHandler) entClient() *dbent.Client {
+ if h == nil || h.authService == nil {
+ return nil
+ }
+ return h.authService.EntClient()
+}
+
+func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
+ if err != nil || defaults == nil {
+ return false
+ }
+ return defaults.ForceEmailOnThirdPartySignup
+}
+
+func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ record, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ return findActiveUserByID(ctx, client, record.UserID)
+}
+
+func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
+func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
+func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
+func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
+
+func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "linuxdo")
+}
+
+func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
+
+func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "wechat")
+}
+
+func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "")
+}
+
+// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
+// pending OAuth account-creation flow.
+// POST /api/v1/auth/oauth/pending/send-verify-code
+func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
+ var req sendPendingOAuthVerifyCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ _, session, _, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ } else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, SendVerifyCodeResponse{
+ Message: "Verification code sent successfully",
+ Countdown: result.Countdown,
+ })
+}
+
+func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ existing, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
+ Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
+ }
+ if existing != nil && !req.hasDecision() {
+ return existing, nil
+ }
+ if existing == nil && !req.hasDecision() {
+ return nil, nil
+ }
+
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if existing != nil {
+ input.AdoptDisplayName = existing.AdoptDisplayName
+ input.AdoptAvatar = existing.AdoptAvatar
+ input.IdentityID = existing.IdentityID
+ }
+ if req.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *req.AdoptDisplayName
+ }
+ if req.AdoptAvatar != nil {
+ input.AdoptAvatar = *req.AdoptAvatar
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
+ if err != nil {
+ return nil, err
+ }
+ if decision != nil {
+ return decision, nil
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ })
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+func updatePendingOAuthSessionProgress(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ intent string,
+ resolvedEmail string,
+ targetUserID *int64,
+ completionResponse map[string]any,
+) (*dbent.PendingAuthSession, error) {
+ if client == nil || session == nil {
+ return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+
+ localFlowState := clonePendingMap(session.LocalFlowState)
+ localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
+
+ update := client.PendingAuthSession.UpdateOneID(session.ID).
+ SetIntent(strings.TrimSpace(intent)).
+ SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
+ SetLocalFlowState(localFlowState)
+ if targetUserID != nil && *targetUserID > 0 {
+ update = update.SetTargetUserID(*targetUserID)
+ } else {
+ update = update.ClearTargetUserID()
+ }
+ return update.Save(ctx)
+}
+
+func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
+ if session == nil {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return *session.TargetUserID, nil
+ }
+ email := strings.TrimSpace(session.ResolvedEmail)
+ if email == "" {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
+ }
+ return 0, err
+ }
+ return userEntity.ID, nil
+}
+
+func userNormalizedEmailPredicate(email string) predicate.User {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.P(func(b *entsql.Builder) {
+ b.WriteString("LOWER(TRIM(").
+ Ident(s.C(dbuser.FieldEmail)).
+ WriteString(")) = ").
+ Arg(normalized)
+ }))
+ })
+}
+
+func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) {
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ matches, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+ return matches[0], nil
+}
+
+func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
+ if client == nil || session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity == nil || identity.UserID <= 0 {
+ return nil
+ }
+
+ activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
+ if err != nil {
+ return err
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return nil
+}
+
+func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
+ if session == nil {
+ return nil
+ }
+ switch strings.TrimSpace(session.ProviderType) {
+ case "oidc":
+ issuer := strings.TrimSpace(session.ProviderKey)
+ if issuer == "" {
+ issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ }
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ default:
+ issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ }
+}
+
+func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
+ return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
+ }
+
+ client := tx.Client()
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, err
+ }
+ if identity != nil {
+ if identity.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
+ if err != nil {
+ return nil, err
+ }
+ if activeOwner != nil {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return client.AuthIdentity.UpdateOneID(identity.ID).
+ SetUserID(userID).
+ Save(ctx)
+ }
+ return identity, nil
+ }
+
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(strings.TrimSpace(session.ProviderType)).
+ SetProviderKey(strings.TrimSpace(session.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
+ SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ return create.Save(ctx)
+}
+
+func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ client := tx.Client()
+ providerType := strings.TrimSpace(session.ProviderType)
+ providerKey := strings.TrimSpace(session.ProviderKey)
+ providerSubject := strings.TrimSpace(session.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(providerKey)
+ channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
+ channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
+ channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
+ metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
+
+ identityRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+
+ var legacyOpenIDIdentity *dbent.AuthIdentity
+ if channelSubject != "" && channelSubject != providerSubject {
+ legacyOpenIDRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(channelSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ switch {
+ case identity != nil:
+ update := client.AuthIdentity.UpdateOneID(identity.ID).
+ SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
+ if identity.UserID != userID {
+ update = update.SetUserID(userID)
+ }
+ if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey {
+ update = update.SetProviderKey(providerKey)
+ }
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ case legacyOpenIDIdentity != nil:
+ update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if channel == "" || channelAppID == "" || channelSubject == "" {
+ return identity, nil
+ }
+
+ channelRecords, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+
+ channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
+ if channelRecord == nil {
+ if _, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetChannel(channel).
+ SetChannelAppID(channelAppID).
+ SetChannelSubject(channelSubject).
+ SetMetadata(channelMetadata).
+ Save(ctx); err != nil {
+ return nil, err
+ }
+ return identity, nil
+ }
+
+ updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
+ SetIdentityID(identity.ID).
+ SetMetadata(channelMetadata)
+ if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey {
+ updateChannel = updateChannel.SetProviderKey(providerKey)
+ }
+ _, err = updateChannel.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return identity, nil
+}
+
+func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
+ var preferred *dbent.AuthIdentity
+ var fallback *dbent.AuthIdentity
+ hasCanonicalKey := false
+ for _, record := range records {
+ if record == nil {
+ continue
+ }
+ if record.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, record.UserID)
+ if err != nil {
+ return nil, false, err
+ }
+ if activeOwner != nil {
+ return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
+ hasCanonicalKey = true
+ if preferred == nil {
+ preferred = record
+ }
+ continue
+ }
+ if fallback == nil {
+ fallback = record
+ }
+ }
+ if preferred != nil {
+ return preferred, hasCanonicalKey, nil
+ }
+ return fallback, hasCanonicalKey, nil
+}
+
+func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
+ var preferred *dbent.AuthIdentityChannel
+ var fallback *dbent.AuthIdentityChannel
+ hasCanonicalKey := false
+ for _, record := range records {
+ if record == nil {
+ continue
+ }
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID)
+ if err != nil {
+ return nil, false, err
+ }
+ if activeOwner != nil {
+ return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
+ hasCanonicalKey = true
+ if preferred == nil {
+ preferred = record
+ }
+ continue
+ }
+ if fallback == nil {
+ fallback = record
+ }
+ }
+ if preferred != nil {
+ return preferred, hasCanonicalKey, nil
+ }
+ return fallback, hasCanonicalKey, nil
+}
+
+func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) {
+ if client == nil || userID <= 0 {
+ return nil, nil
+ }
+ userEntity, err := client.User.Get(ctx, userID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
+ }
+ if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
+ return nil, service.ErrUserNotActive
+ }
+ return userEntity, nil
+}
+
+func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
+ if channel == nil {
+ return map[string]any{}
+ }
+ return cloneOAuthMetadata(channel.Metadata)
+}
+
+func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
+ if session == nil || decision == nil {
+ return false
+ }
+ switch strings.ToLower(strings.TrimSpace(session.Intent)) {
+ case "bind_current_user", "login", "adopt_existing_user_by_email":
+ return true
+ default:
+ return decision.AdoptDisplayName || decision.AdoptAvatar
+ }
+}
+
+func shouldSkipAvatarAdoption(err error) bool {
+ return errors.Is(err, service.ErrAvatarInvalid) ||
+ errors.Is(err, service.ErrAvatarTooLarge) ||
+ errors.Is(err, service.ErrAvatarNotImage)
+}
+
+func applyPendingOAuthBinding(
+ ctx context.Context,
+ client *dbent.Client,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+ forceBind bool,
+ applyFirstBindDefaults bool,
+) error {
+ if client == nil || session == nil {
+ return nil
+ }
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+ return nil
+ }
+
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func applyPendingOAuthBindingTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+ forceBind bool,
+ applyFirstBindDefaults bool,
+) error {
+ if tx == nil || session == nil {
+ return nil
+ }
+ if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+ return nil
+ }
+
+ targetUserID := int64(0)
+ if overrideUserID != nil && *overrideUserID > 0 {
+ targetUserID = *overrideUserID
+ } else {
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
+ if err != nil {
+ return err
+ }
+ targetUserID = resolvedUserID
+ }
+
+ adoptedDisplayName := ""
+ if decision != nil && decision.AdoptDisplayName {
+ adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
+ }
+ adoptedAvatarURL := ""
+ if decision != nil && decision.AdoptAvatar {
+ adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
+ }
+ shouldAdoptAvatar := false
+ if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+ if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
+ shouldAdoptAvatar = true
+ } else if !shouldSkipAvatarAdoption(err) {
+ return err
+ }
+ }
+
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
+ if err := tx.Client().User.UpdateOneID(targetUserID).
+ SetUsername(adoptedDisplayName).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
+ if err != nil {
+ return err
+ }
+
+ metadata := cloneOAuthMetadata(identity.Metadata)
+ for key, value := range session.UpstreamIdentityClaims {
+ metadata[key] = value
+ }
+ if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
+ metadata["display_name"] = adoptedDisplayName
+ }
+ if shouldAdoptAvatar {
+ metadata["avatar_url"] = adoptedAvatarURL
+ }
+
+ updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ if _, err := updateIdentity.Save(ctx); err != nil {
+ return err
+ }
+
+ if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
+ if _, err := tx.Client().IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(identity.ID),
+ identityadoptiondecision.IDNEQ(decision.ID),
+ ).
+ ClearIdentityID().
+ Save(ctx); err != nil {
+ return err
+ }
+ if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
+ SetIdentityID(identity.ID).
+ Save(ctx); err != nil {
+ return err
+ }
+ }
+
+ if applyFirstBindDefaults && authService != nil {
+ if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
+ return err
+ }
+ }
+
+ if shouldAdoptAvatar && userService != nil {
+ if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func consumePendingOAuthBrowserSessionTx(
+ ctx context.Context,
+ tx *dbent.Tx,
+ session *dbent.PendingAuthSession,
+) error {
+ if tx == nil || session == nil {
+ return service.ErrPendingAuthSessionNotFound
+ }
+
+ storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return service.ErrPendingAuthSessionNotFound
+ }
+ return err
+ }
+
+ now := time.Now().UTC()
+ if storedSession.ConsumedAt != nil {
+ return service.ErrPendingAuthSessionConsumed
+ }
+ if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
+ return service.ErrPendingAuthSessionExpired
+ }
+ if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
+ strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return service.ErrPendingAuthBrowserMismatch
+ }
+
+ if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
+ SetConsumedAt(now).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt().
+ Save(ctx); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func applyPendingOAuthAdoptionAndConsumeSession(
+ ctx context.Context,
+ client *dbent.Client,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ userID int64,
+) error {
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ if session == nil || userID <= 0 {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
+ return err
+ }
+ if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func applyPendingOAuthAdoption(
+ ctx context.Context,
+ client *dbent.Client,
+ authService *service.AuthService,
+ userService *service.UserService,
+ session *dbent.PendingAuthSession,
+ decision *dbent.IdentityAdoptionDecision,
+ overrideUserID *int64,
+) error {
+ return applyPendingOAuthBinding(
+ ctx,
+ client,
+ authService,
+ userService,
+ session,
+ decision,
+ overrideUserID,
+ false,
+ strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"),
+ )
+}
+
+func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
+ if len(payload) == 0 || len(upstream) == 0 {
+ return
+ }
+
+ displayName := pendingSessionStringValue(upstream, "suggested_display_name")
+ avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url")
+
+ if displayName != "" {
+ if _, exists := payload["suggested_display_name"]; !exists {
+ payload["suggested_display_name"] = displayName
+ }
+ }
+ if avatarURL != "" {
+ if _, exists := payload["suggested_avatar_url"]; !exists {
+ payload["suggested_avatar_url"] = avatarURL
+ }
+ }
+ if displayName != "" || avatarURL != "" {
+ payload["adoption_required"] = true
+ }
+}
+
+func pendingOAuthIdentityExistsForUser(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ userID int64,
+) (bool, error) {
+ if client == nil || session == nil || userID <= 0 {
+ return false, nil
+ }
+
+ providerType := strings.TrimSpace(session.ProviderType)
+ providerKey := strings.TrimSpace(session.ProviderKey)
+ providerSubject := strings.TrimSpace(session.ProviderSubject)
+ if providerType == "" || providerSubject == "" {
+ return false, nil
+ }
+
+ query := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ authidentity.UserIDEQ(userID),
+ )
+ if strings.EqualFold(providerType, "wechat") {
+ query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...))
+ } else if providerKey != "" {
+ query = query.Where(authidentity.ProviderKeyEQ(providerKey))
+ }
+
+ count, err := query.Count(ctx)
+ if err != nil {
+ return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ return count > 0, nil
+}
+
+func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ payload map[string]any,
+) (bool, error) {
+ if session == nil || len(payload) == 0 {
+ return false, nil
+ }
+ if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
+ return false, nil
+ }
+ if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
+ pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" {
+ return false, nil
+ }
+
+ return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID)
+}
+
+func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ return nil, nil, clearCookies, err
+ }
+
+ return svc, session, clearCookies, nil
+}
+
+func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
+ if c == nil || c.Request == nil {
+ return
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ return
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return
+ }
+ _, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+}
+
+func clearOAuthLogoutCookies(c *gin.Context) {
+ secureCookie := isRequestHTTPS(c)
+
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ clearOAuthBindAccessTokenCookie(c, secureCookie)
+
+ clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
+
+ oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
+ oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
+ oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+
+ wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+
+ wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
+}
+
+func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
+ completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
+ payload := gin.H{
+ "auth_result": "pending_session",
+ "provider": strings.TrimSpace(session.ProviderType),
+ "intent": strings.TrimSpace(session.Intent),
+ }
+ for key, value := range completionResponse {
+ payload[key] = value
+ }
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ payload["email"] = email
+ }
+ return payload
+}
+
+func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
+ normalized := clonePendingMap(payload)
+ for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+ delete(normalized, key)
+ }
+ step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
+ switch step {
+ case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
+ normalized["step"] = oauthPendingChoiceStep
+ }
+ if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
+ normalized["adoption_required"] = true
+ }
+ if _, exists := normalized["adoption_required"]; !exists {
+ if _, hasChoiceFields := normalized["email_binding_required"]; hasChoiceFields {
+ normalized["adoption_required"] = true
+ }
+ }
+ return normalized
+}
+
+func pendingOAuthChoiceCompletionResponse(session *dbent.PendingAuthSession, email string) map[string]any {
+ response := mergePendingCompletionResponse(session, map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "force_email_on_signup": true,
+ "email_binding_required": true,
+ "existing_account_bindable": true,
+ })
+ if email = strings.TrimSpace(email); email != "" {
+ response["email"] = email
+ response["resolved_email"] = email
+ }
+ return response
+}
+
+func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
+ c *gin.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ targetUser *dbent.User,
+ email string,
+) (*dbent.PendingAuthSession, error) {
+ completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
+ var targetUserID *int64
+ if targetUser != nil && targetUser.ID > 0 {
+ targetUserID = &targetUser.ID
+ }
+ session, err := updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ strings.TrimSpace(session.Intent),
+ email,
+ targetUserID,
+ completionResponse,
+ )
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+ }
+ return session, nil
+}
+
+func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
+ var req bindPendingOAuthLoginRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
+ response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
+ return
+ }
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
+ tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession(
+ c.Request.Context(),
+ user.ID,
+ user.Email,
+ session.SessionToken,
+ session.BrowserSessionKey,
+ )
+ if err != nil {
+ response.InternalError(c, "Failed to create 2FA session")
+ return
+ }
+ response.Success(c, TotpLoginResponse{
+ Requires2FA: true,
+ TempToken: tempToken,
+ UserEmailMasked: service.MaskEmail(user.Email),
+ })
+ return
+ }
+ if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
+ if err != nil {
+ response.InternalError(c, "Failed to generate token pair")
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+func respondPendingOAuthBindingApplyError(c *gin.Context, err error) {
+ if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+}
+
+func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
+ var req createPendingOAuthAccountRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ _, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+ response.BadRequest(c, "Pending oauth session provider mismatch")
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
+ if err != nil {
+ switch {
+ case errors.Is(err, service.ErrUserNotFound):
+ existingUser = nil
+ case infraerrors.Code(err) >= http.StatusBadRequest && infraerrors.Code(err) < http.StatusInternalServerError:
+ response.ErrorFrom(c, err)
+ return
+ default:
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
+ return
+ }
+ }
+ if existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
+ c.Request.Context(),
+ email,
+ req.Password,
+ strings.TrimSpace(req.VerifyCode),
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ )
+ if err != nil {
+ if errors.Is(err, service.ErrEmailExists) {
+ existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
+ if lookupErr != nil {
+ response.ErrorFrom(c, lookupErr)
+ return
+ }
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ rollbackCreatedUser := func(originalErr error) bool {
+ if user == nil || user.ID <= 0 {
+ return false
+ }
+ if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
+ c.Request.Context(),
+ user.ID,
+ strings.TrimSpace(req.InvitationCode),
+ ); rollbackErr != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer(
+ "PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
+ "failed to rollback pending oauth account creation",
+ ).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
+ return true
+ }
+ user = nil
+ return false
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+ if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ tx, err := client.Tx(c.Request.Context())
+ if err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ defer func() { _ = tx.Rollback() }()
+ txCtx := dbent.NewTxContext(c.Request.Context(), tx)
+
+ if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+
+ if err := h.authService.FinalizeOAuthEmailAccount(
+ txCtx,
+ user,
+ strings.TrimSpace(req.InvitationCode),
+ strings.TrimSpace(session.ProviderType),
+ strings.TrimSpace(req.AffCode),
+ ); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if pendingOAuthCreateAccountPreCommitHook != nil {
+ if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
+ _ = tx.Rollback()
+ if rollbackCreatedUser(err) {
+ return
+ }
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ if rollbackCreatedUser(err) {
+ return
+ }
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearCookies()
+ writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
+// POST /api/v1/auth/oauth/pending/exchange
+func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
+ secureCookie := isRequestHTTPS(c)
+ clearCookies := func() {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ }
+ adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
+ if err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil || strings.TrimSpace(sessionToken) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+ clearCookies()
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ payload, ok := readCompletionResponse(session.LocalFlowState)
+ if !ok {
+ clearCookies()
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
+ return
+ }
+ payload = normalizePendingOAuthCompletionResponse(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := payload["redirect"]; !exists {
+ payload["redirect"] = session.RedirectTo
+ }
+ }
+ applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
+
+ canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
+ var loginUser *service.User
+ if canIssueTokenPair {
+ loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensureLoginUserActive(loginUser); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+ skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
+ if err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+ if skipAdoptionPrompt {
+ delete(payload, "adoption_required")
+ }
+
+ if pendingSessionWantsInvitation(payload) {
+ if adoptionDecision.hasDecision() {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ _ = decision
+ }
+ response.Success(c, payload)
+ return
+ }
+ if !adoptionDecision.hasDecision() {
+ adoptionRequired, _ := payload["adoption_required"].(bool)
+ if adoptionRequired {
+ response.Success(c, payload)
+ return
+ }
+ }
+
+ decisionReq := adoptionDecision
+ if !decisionReq.hasDecision() {
+ adoptDisplayName := false
+ adoptAvatar := false
+ decisionReq = oauthAdoptionDecisionRequest{
+ AdoptDisplayName: &adoptDisplayName,
+ AdoptAvatar: &adoptAvatar,
+ }
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
+
+ if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearCookies()
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if canIssueTokenPair {
+ tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
+ if err != nil {
+ clearCookies()
+ response.InternalError(c, "Failed to generate token pair")
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
+ payload["access_token"] = tokenPair.AccessToken
+ payload["refresh_token"] = tokenPair.RefreshToken
+ payload["expires_in"] = tokenPair.ExpiresIn
+ payload["token_type"] = "Bearer"
+ }
+
+ clearCookies()
+ response.Success(c, payload)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
new file mode 100644
index 00000000..ffe9ff5f
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -0,0 +1,2996 @@
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/pquerna/otp/totp"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
+ payload := map[string]any{
+ "access_token": "token",
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Alice", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
+
+func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) {
+ payload := map[string]any{
+ "suggested_display_name": "Existing",
+ "adoption_required": false,
+ }
+ upstream := map[string]any{
+ "suggested_display_name": "Alice",
+ "suggested_avatar_url": "https://cdn.example/avatar.png",
+ }
+
+ applySuggestedProfileToCompletionResponse(payload, upstream)
+
+ require.Equal(t, "Existing", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+ require.Equal(t, true, payload["adoption_required"])
+}
+
+func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
+
+ setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
+
+ cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, cookie)
+ require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
+}
+
+func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("linuxdo-123@linuxdo-connect.invalid").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "https://cdn.example/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previewRecorder := httptest.NewRecorder()
+ previewCtx, _ := gin.CreateTestContext(previewRecorder)
+ previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ previewCtx.Request = previewReq
+
+ handler.ExchangePendingOAuthCompletion(previewCtx)
+
+ require.Equal(t, http.StatusOK, previewRecorder.Code)
+ previewData := decodeJSONResponseData(t, previewRecorder)
+ require.Equal(t, "Alice Example", previewData["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
+ require.Equal(t, true, previewData["adoption_required"])
+
+ storedUser, err := client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "legacy-name", storedUser.Username)
+
+ previewSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, previewSession.ConsumedAt)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ finalizeRecorder := httptest.NewRecorder()
+ finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+ finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ finalizeReq.Header.Set("Content-Type", "application/json")
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+ finalizeCtx.Request = finalizeReq
+
+ handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+ require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", storedUser.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.NotNil(t, avatar)
+ require.Equal(t, "remote_url", avatar.StorageProvider)
+ require.Equal(t, "https://cdn.example/alice.png", avatar.URL)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("invalid-avatar@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-invalid-avatar-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invalid-avatar-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("browser-invalid-avatar-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Alice Example",
+ "suggested_avatar_url": "/avatars/alice.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invalid-avatar-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+ _, hasAdoptedAvatar := identity.Metadata["avatar_url"]
+ require.False(t, hasAdoptedAvatar)
+
+ avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+ require.Nil(t, avatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("bind-target@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-pending-session-token").
+ SetIntent("bind_current_user").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("bind-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "Bound Example",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "redirect": "/settings/profile",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previewRecorder := httptest.NewRecorder()
+ previewCtx, _ := gin.CreateTestContext(previewRecorder)
+ previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
+ previewCtx.Request = previewReq
+
+ handler.ExchangePendingOAuthCompletion(previewCtx)
+
+ require.Equal(t, http.StatusOK, previewRecorder.Code)
+ previewData := decodeJSONResponseData(t, previewRecorder)
+ require.Equal(t, "Bound Example", previewData["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"])
+ require.Equal(t, true, previewData["adoption_required"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("bind-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ previewSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, previewSession.ConsumedAt)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ finalizeRecorder := httptest.NewRecorder()
+ finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+ finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ finalizeReq.Header.Set("Content-Type", "application/json")
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
+ finalizeCtx.Request = finalizeReq
+
+ handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+ require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+ storedUser, err := client.User.Get(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "legacy-name", storedUser.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"])
+ _, hasDisplayName := identity.Metadata["display_name"]
+ require.False(t, hasDisplayName)
+ _, hasAvatarURL := identity.Metadata["avatar_url"]
+ require.False(t, hasAvatarURL)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("bind-conflict-target@example.com").
+ SetUsername("target-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ ownerUser, err := client.User.Create().
+ SetEmail("bind-conflict-owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ existingIdentity, err := client.AuthIdentity.Create().
+ SetUserID(ownerUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("conflict-123").
+ SetMetadata(map[string]any{"username": "owner-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-conflict-session-token").
+ SetIntent("bind_current_user").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("conflict-123").
+ SetTargetUserID(targetUser.ID).
+ SetResolvedEmail(targetUser.Email).
+ SetBrowserSessionKey("bind-conflict-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Conflict Example",
+ "suggested_avatar_url": "https://cdn.example/conflict.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"])
+
+ identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, ownerUser.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-false@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-false-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-false-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-false-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Login Example",
+ "suggested_avatar_url": "https://cdn.example/login.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("login-false-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionLoginReassignsExistingDecisionIdentityReference(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-reassign@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ existingIdentity, err := client.AuthIdentity.Create().
+ SetUserID(userEntity.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-reassign-123").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previousSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-reassign-previous-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-reassign-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-reassign-previous-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "previous-access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ previousDecision, err := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(previousSession.ID).
+ SetIdentityID(existingIdentity.ID).
+ SetAdoptDisplayName(true).
+ SetAdoptAvatar(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-reassign-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-reassign-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-reassign-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Login Reassign",
+ "suggested_avatar_url": "https://cdn.example/login-reassign.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(session.ID).
+ SetAdoptDisplayName(false).
+ SetAdoptAvatar(false).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-reassign-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ reloadedPrevious, err := client.IdentityAdoptionDecision.Get(ctx, previousDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedPrevious.IdentityID)
+
+ currentDecision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, currentDecision.IdentityID)
+ require.Equal(t, existingIdentity.ID, *currentDecision.IdentityID)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("login-nodecision@example.com").
+ SetUsername("legacy-name").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-nodecision-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("login-nodecision-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("login-nodecision-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "login-nodecision-user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("login-nodecision-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("existing-login@example.com").
+ SetUsername("existing-login-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(userEntity.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("existing-login-123").
+ SetMetadata(map[string]any{
+ "username": "existing-login-user",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-login-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("existing-login-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("existing-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Existing Login Example",
+ "suggested_avatar_url": "https://cdn.example/existing-login.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ payload := decodeJSONResponseData(t, recorder)
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.NotEqual(t, "legacy-access-token", payload["access_token"])
+ require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
+ require.NotContains(t, payload, "adoption_required")
+
+ accessToken, ok := payload["access_token"].(string)
+ require.True(t, ok)
+ claims, err := handler.authService.ValidateToken(accessToken)
+ require.NoError(t, err)
+ reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID)
+ require.NoError(t, err)
+ require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
+
+ decisionCount, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, decisionCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+
+ completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.NotContains(t, completion, "access_token")
+ require.NotContains(t, completion, "refresh_token")
+ require.NotContains(t, completion, "expires_in")
+ require.NotContains(t, completion, "token_type")
+ require.Equal(t, "/dashboard", completion["redirect"])
+}
+
+func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("blocked@example.com").
+ SetUsername("blocked-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("blocked-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("blocked-subject-123").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("blocked-backend-mode-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "access_token": "access-token",
+ "refresh_token": "refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ userEntity, err := client.User.Create().
+ SetEmail("disabled-linked@example.com").
+ SetUsername("disabled-linked-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("disabled-linked-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("disabled-linked-subject").
+ SetTargetUserID(userEntity.ID).
+ SetResolvedEmail(userEntity.Email).
+ SetBrowserSessionKey("disabled-linked-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Disabled Linked User",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "redirect": "/dashboard",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
+ payload := normalizePendingOAuthCompletionResponse(map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ })
+
+ require.NotContains(t, payload, "access_token")
+ require.NotContains(t, payload, "refresh_token")
+ require.NotContains(t, payload, "expires_in")
+ require.NotContains(t, payload, "token_type")
+ require.Equal(t, "/dashboard", payload["redirect"])
+}
+
+func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, true)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("invitation-required-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invitation-123").
+ SetBrowserSessionKey("invitation-required-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Invite Example",
+ "suggested_avatar_url": "https://cdn.example/invite.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "error": "invitation_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ data := decodeJSONResponseData(t, recorder)
+ require.Equal(t, "invitation_required", data["error"])
+ require.Equal(t, true, data["adoption_required"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invitation-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-123").
+ SetBrowserSessionKey("create-account-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Fresh OIDC User",
+ "suggested_avatar_url": "https://cdn.example/fresh.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusActive, createdUser.Status)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-create-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, createdUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-123").
+ SetBrowserSessionKey("existing-email-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, oauthIntentLogin, payload["intent"])
+ require.Equal(t, "oidc", payload["provider"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, true, payload["adoption_required"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+ require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+ require.Nil(t, storedSession.ConsumedAt)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-existing-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+}
+
+func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-normalized-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-normalized-123").
+ SetBrowserSessionKey("existing-email-normalized-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, oauthIntentLogin, payload["intent"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
+func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-send-code-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-send-code-123").
+ SetBrowserSessionKey("existing-email-send-code-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "email_required",
+ },
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.SendPendingOAuthVerifyCode(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
+func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-backend-mode-123").
+ SetBrowserSessionKey("create-account-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
+ handler, _ := newOAuthPendingFlowTestHandler(t, false)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
+ ginCtx.Request = req
+
+ handler.Logout(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
+ require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
+ require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
+}
+
+func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ conflictOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(conflictOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetMetadata(map[string]any{
+ "username": "owner-user",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ invitation, err := client.RedeemCode.Create().
+ SetCode("INVITE123").
+ SetType(service.RedeemTypeInvitation).
+ SetStatus(service.StatusUnused).
+ SetValue(0).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-conflict-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetBrowserSessionKey("create-account-conflict-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusUnused, storedInvitation.Status)
+ require.Nil(t, storedInvitation.UsedBy)
+ require.Nil(t, storedInvitation.UsedAt)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ userRepoOptions: oauthPendingFlowUserRepoOptions{
+ rejectDeleteWhileAuthIdentityExists: true,
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-finalize-failure-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-finalize-failure-123").
+ SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
+ return errors.New("forced post-bind failure")
+ }
+ t.Cleanup(func() {
+ pendingOAuthCreateAccountPreCommitHook = nil
+ })
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-backend-mode-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-backend-mode-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-invalid-password-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-invalid-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusUnauthorized, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginReclaimsIdentityOwnedBySoftDeletedUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ oldOwnerHash, err := handler.authService.HashPassword("old-secret")
+ require.NoError(t, err)
+ oldOwner, err := client.User.Create().
+ SetEmail("old-owner@example.com").
+ SetUsername("old-owner").
+ SetPasswordHash(oldOwnerHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(oldOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-soft-deleted-123").
+ SetMetadata(map[string]any{"username": "old-owner"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.User.Delete().Where(dbuser.IDEQ(oldOwner.ID)).Exec(ctx)
+ require.NoError(t, err)
+
+ newOwnerHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ newOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(newOwnerHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-soft-deleted-owner-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-soft-deleted-123").
+ SetTargetUserID(newOwner.ID).
+ SetResolvedEmail(newOwner.Email).
+ SetBrowserSessionKey("bind-login-soft-deleted-owner-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Recovered OIDC User",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-soft-deleted-owner-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err = client.AuthIdentity.Get(ctx, identity.ID)
+ require.NoError(t, err)
+ require.Equal(t, newOwner.ID, identity.UserID)
+}
+
+func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
+ defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5",
+ service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3",
+ service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
+ },
+ defaultSubAssigner: defaultSubAssigner,
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("first-bind-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-first-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("first-bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ firstRecorder := httptest.NewRecorder()
+ firstGinCtx, _ := gin.CreateTestContext(firstRecorder)
+ firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody)
+ firstReq.Header.Set("Content-Type", "application/json")
+ firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)})
+ firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")})
+ firstGinCtx.Request = firstReq
+
+ handler.BindOIDCOAuthLogin(firstGinCtx)
+
+ require.Equal(t, http.StatusOK, firstRecorder.Code)
+
+ storedUser, err := client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 17.5, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.Zero(t, storedUser.TotalRecharged)
+ require.Len(t, defaultSubAssigner.calls, 1)
+ require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID)
+ require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID)
+ require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+
+ secondSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("second-bind-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-second-456").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("second-bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Second OIDC User",
+ "suggested_avatar_url": "https://cdn.example/second.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ secondRecorder := httptest.NewRecorder()
+ secondGinCtx, _ := gin.CreateTestContext(secondRecorder)
+ secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody)
+ secondReq.Header.Set("Content-Type", "application/json")
+ secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)})
+ secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")})
+ secondGinCtx.Request = secondReq
+
+ handler.BindOIDCOAuthLogin(secondGinCtx)
+
+ require.Equal(t, http.StatusOK, secondRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 17.5, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.Zero(t, storedUser.TotalRecharged)
+ require.Len(t, defaultSubAssigner.calls, 1)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+}
+
+func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ _ = handler
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("resolve-target-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-target-123").
+ SetResolvedEmail("owner@example.com").
+ SetBrowserSessionKey("resolve-target-browser-session-key").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, resolvedUserID)
+}
+
+func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) {
+ totpCache := &oauthPendingFlowTotpCacheStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyTotpEnabled: "true",
+ },
+ totpCache: totpCache,
+ totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ totpEnabledAt := time.Now().UTC().Add(-time.Hour)
+ secret := "JBSWY3DPEHPK3PXP"
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetTotpEnabled(true).
+ SetTotpSecretEncrypted(secret).
+ SetTotpEnabledAt(totpEnabledAt).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-2fa-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-2fa-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-2fa-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ data := decodeJSONResponseData(t, recorder)
+ require.Equal(t, true, data["requires_2fa"])
+ require.Equal(t, "o***r@example.com", data["user_email_masked"])
+ tempToken, ok := data["temp_token"].(string)
+ require.True(t, ok)
+ require.NotEmpty(t, tempToken)
+
+ loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
+ require.NoError(t, err)
+ require.NotNil(t, loginSession)
+ require.NotNil(t, loginSession.PendingOAuthBind)
+ require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken)
+ require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
+ totpCache := &oauthPendingFlowTotpCacheStub{}
+ defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyTotpEnabled: "true",
+ service.SettingKeyAuthSourceDefaultOIDCBalance: "8",
+ service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2",
+ service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
+ },
+ defaultSubAssigner: defaultSubAssigner,
+ totpCache: totpCache,
+ totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ totpEnabledAt := time.Now().UTC().Add(-time.Hour)
+ secret := "JBSWY3DPEHPK3PXP"
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(4).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetTotpEnabled(true).
+ SetTotpSecretEncrypted(secret).
+ SetTotpEnabledAt(totpEnabledAt).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-2fa-pending-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-login-2fa-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("login-2fa-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(session.ID).
+ SetAdoptDisplayName(false).
+ SetAdoptAvatar(false).
+ Save(ctx)
+ require.NoError(t, err)
+
+ tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession(
+ ctx,
+ existingUser.ID,
+ existingUser.Email,
+ session.SessionToken,
+ session.BrowserSessionKey,
+ )
+ require.NoError(t, err)
+
+ code, err := totp.GenerateCode(secret, time.Now().UTC())
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)})
+ ginCtx.Request = req
+
+ handler.Login2FA(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ payload := decodeJSONResponseData(t, recorder)
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ accessToken, ok := payload["access_token"].(string)
+ require.True(t, ok)
+ claims, err := handler.authService.ValidateToken(accessToken)
+ require.NoError(t, err)
+ reloadedUser, err := handler.userService.GetByID(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-login-2fa-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+
+ loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
+ require.NoError(t, err)
+ require.Nil(t, loginSession)
+
+ storedUser, err := client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 9.5, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+ require.Empty(t, defaultSubAssigner.calls)
+}
+
+func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
+}
+
+func newOAuthPendingFlowTestHandlerWithEmailVerification(
+ t *testing.T,
+ invitationEnabled bool,
+ email string,
+ code string,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ cache := &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ email: {
+ Code: code,
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ }
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
+}
+
+func newOAuthPendingFlowTestHandlerWithOptions(
+ t *testing.T,
+ invitationEnabled bool,
+ emailVerifyEnabled bool,
+ emailCache service.EmailCache,
+) (*AuthHandler, *dbent.Client) {
+ return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ invitationEnabled: invitationEnabled,
+ emailVerifyEnabled: emailVerifyEnabled,
+ emailCache: emailCache,
+ })
+}
+
+type oauthPendingFlowTestHandlerOptions struct {
+ invitationEnabled bool
+ emailVerifyEnabled bool
+ emailCache service.EmailCache
+ settingValues map[string]string
+ defaultSubAssigner service.DefaultSubscriptionAssigner
+ totpCache service.TotpCache
+ totpEncryptor service.SecretEncryptor
+ userRepoOptions oauthPendingFlowUserRepoOptions
+}
+
+func newOAuthPendingFlowTestHandlerWithDependencies(
+ t *testing.T,
+ options oauthPendingFlowTestHandlerOptions,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_avatars (
+ user_id INTEGER PRIMARY KEY,
+ storage_provider TEXT NOT NULL,
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL,
+ content_type TEXT NOT NULL DEFAULT '',
+ byte_size INTEGER NOT NULL DEFAULT 0,
+ sha256 TEXT NOT NULL DEFAULT '',
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ settingValues := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
+ }
+ for key, value := range options.settingValues {
+ settingValues[key] = value
+ }
+ settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
+ userRepo := &oauthPendingFlowUserRepo{
+ client: client,
+ options: options.userRepoOptions,
+ }
+ redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
+ var emailService *service.EmailService
+ if options.emailCache != nil {
+ emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
+ },
+ }, options.emailCache)
+ }
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &oauthPendingFlowRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ emailService,
+ nil,
+ nil,
+ nil,
+ options.defaultSubAssigner,
+ nil,
+ )
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ var totpSvc *service.TotpService
+ if options.totpCache != nil || options.totpEncryptor != nil {
+ totpCache := options.totpCache
+ if totpCache == nil {
+ totpCache = &oauthPendingFlowTotpCacheStub{}
+ }
+ totpEncryptor := options.totpEncryptor
+ if totpEncryptor == nil {
+ totpEncryptor = oauthPendingFlowTotpEncryptorStub{}
+ }
+ totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil)
+ }
+
+ return &AuthHandler{
+ authService: authSvc,
+ userService: userSvc,
+ settingSvc: settingSvc,
+ totpService: totpSvc,
+ }, client
+}
+
+func boolSettingValue(v bool) string {
+ if v {
+ return "true"
+ }
+ return "false"
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+type oauthPendingFlowSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type oauthPendingFlowRefreshTokenCacheStub struct{}
+
+type oauthPendingFlowEmailCacheStub struct {
+ verificationCodes map[string]*service.VerificationCodeData
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
+ if s == nil || s.verificationCodes == nil {
+ return nil, nil
+ }
+ return s.verificationCodes[email], nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
+ if s.verificationCodes == nil {
+ s.verificationCodes = map[string]*service.VerificationCodeData{}
+ }
+ s.verificationCodes[email] = data
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
+ delete(s.verificationCodes, email)
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+type oauthPendingFlowRedeemCodeRepo struct {
+ client *dbent.Client
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
+ entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ notes := ""
+ if entity.Notes != nil {
+ notes = *entity.Notes
+ }
+ return &service.RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: notes,
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ update := r.client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
+ affected, err := r.client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
+ SetStatus(service.StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrRedeemCodeUsed
+ }
+ return nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var envelope struct {
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
+ return envelope.Data
+}
+
+func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ return payload
+}
+
+type oauthPendingFlowAvatarRecord struct {
+ StorageProvider string
+ URL string
+}
+
+func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ require.NoError(t, rows.Err())
+ return nil
+ }
+
+ var record oauthPendingFlowAvatarRecord
+ require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL))
+ require.NoError(t, rows.Err())
+ return &record
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ []any{userID, providerType, grantReason},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next())
+ var count int
+ require.NoError(t, rows.Scan(&count))
+ require.False(t, rows.Next())
+ return count
+}
+
+type oauthPendingFlowUserRepo struct {
+ client *dbent.Client
+ options oauthPendingFlowUserRepoOptions
+}
+
+type oauthPendingFlowUserRepoOptions struct {
+ rejectDeleteWhileAuthIdentityExists bool
+}
+
+func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.Create().
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
+ SetTotpEnabled(user.TotpEnabled).
+ SetNillableTotpEnabledAt(user.TotpEnabledAt).
+ SetTotalRecharged(user.TotalRecharged).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.ID = entity.ID
+ user.CreatedAt = entity.CreatedAt
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
+ entity, err := r.client.User.Get(ctx, id)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
+ entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrUserNotFound
+ }
+ return nil, err
+ }
+ return oauthPendingFlowServiceUser(entity), nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
+ entity, err := r.client.User.UpdateOneID(user.ID).
+ SetEmail(user.Email).
+ SetUsername(user.Username).
+ SetNotes(user.Notes).
+ SetPasswordHash(user.PasswordHash).
+ SetRole(user.Role).
+ SetBalance(user.Balance).
+ SetConcurrency(user.Concurrency).
+ SetStatus(user.Status).
+ SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
+ SetTotpEnabled(user.TotpEnabled).
+ SetNillableTotpEnabledAt(user.TotpEnabledAt).
+ SetTotalRecharged(user.TotalRecharged).
+ SetSignupSource(user.SignupSource).
+ SetNillableLastLoginAt(user.LastLoginAt).
+ SetNillableLastActiveAt(user.LastActiveAt).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ user.UpdatedAt = entity.UpdatedAt
+ return nil
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ return r.client.User.UpdateOneID(userID).SetLastActiveAt(activeAt).Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
+ if r.options.rejectDeleteWhileAuthIdentityExists {
+ count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
+ if err != nil {
+ return err
+ }
+ if count > 0 {
+ return errors.New("cannot delete user while auth identities still exist")
+ }
+ }
+ return r.client.User.DeleteOneID(id).Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var rows entsql.Rows
+ if err := driver.Query(
+ ctx,
+ `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ ); err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
+}
+
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ if err := driver.Exec(
+ ctx,
+ `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+ON CONFLICT(user_id) DO UPDATE SET
+ storage_provider = excluded.storage_provider,
+ storage_key = excluded.storage_key,
+ url = excluded.url,
+ content_type = excluded.content_type,
+ byte_size = excluded.byte_size,
+ sha256 = excluded.sha256,
+ updated_at = CURRENT_TIMESTAMP`,
+ []any{
+ userID,
+ input.StorageProvider,
+ input.StorageKey,
+ input.URL,
+ input.ContentType,
+ input.ByteSize,
+ input.SHA256,
+ },
+ &result,
+ ); err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ driver := r.client.Driver()
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ driver = tx.Client().Driver()
+ }
+
+ var result entsql.Result
+ return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result)
+}
+
+func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
+ panic("unexpected UpdateBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
+ panic("unexpected DeductBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
+ panic("unexpected UpdateConcurrency call")
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
+ count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
+ return count > 0, err
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ panic("unexpected RemoveGroupFromAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected RemoveGroupFromUserAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ identities, err := r.client.AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(userID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ records := make([]service.UserAuthIdentityRecord, 0, len(identities))
+ for _, identity := range identities {
+ if identity == nil {
+ continue
+ }
+ records = append(records, service.UserAuthIdentityRecord{
+ ProviderType: identity.ProviderType,
+ ProviderKey: identity.ProviderKey,
+ ProviderSubject: identity.ProviderSubject,
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: identity.Metadata,
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ })
+ }
+ return records, nil
+}
+
+func (r *oauthPendingFlowUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ update := r.client.User.UpdateOneID(userID)
+ if encryptedSecret == nil {
+ update = update.ClearTotpSecretEncrypted()
+ } else {
+ update = update.SetTotpSecretEncrypted(*encryptedSecret)
+ }
+ return update.Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error {
+ return r.client.User.UpdateOneID(userID).
+ SetTotpEnabled(true).
+ SetTotpEnabledAt(time.Now().UTC()).
+ Exec(ctx)
+}
+
+func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error {
+ return r.client.User.UpdateOneID(userID).
+ SetTotpEnabled(false).
+ ClearTotpSecretEncrypted().
+ ClearTotpEnabledAt().
+ Exec(ctx)
+}
+
+func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
+ if entity == nil {
+ return nil
+ }
+ return &service.User{
+ ID: entity.ID,
+ Email: entity.Email,
+ Username: entity.Username,
+ Notes: entity.Notes,
+ PasswordHash: entity.PasswordHash,
+ Role: entity.Role,
+ Balance: entity.Balance,
+ Concurrency: entity.Concurrency,
+ Status: entity.Status,
+ SignupSource: entity.SignupSource,
+ LastLoginAt: entity.LastLoginAt,
+ LastActiveAt: entity.LastActiveAt,
+ TotpSecretEncrypted: entity.TotpSecretEncrypted,
+ TotpEnabled: entity.TotpEnabled,
+ TotpEnabledAt: entity.TotpEnabledAt,
+ TotalRecharged: entity.TotalRecharged,
+ CreatedAt: entity.CreatedAt,
+ UpdatedAt: entity.UpdatedAt,
+ }
+}
+
+type oauthPendingFlowDefaultSubAssignerStub struct {
+ calls []service.AssignSubscriptionInput
+}
+
+func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ if input != nil {
+ s.calls = append(s.calls, *input)
+ }
+ return nil, false, nil
+}
+
+type oauthPendingFlowTotpCacheStub struct {
+ setupSessions map[int64]*service.TotpSetupSession
+ loginSessions map[string]*service.TotpLoginSession
+ verifyAttempts map[int64]int
+}
+
+func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) {
+ if s == nil || s.setupSessions == nil {
+ return nil, nil
+ }
+ return s.setupSessions[userID], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error {
+ if s.setupSessions == nil {
+ s.setupSessions = map[int64]*service.TotpSetupSession{}
+ }
+ s.setupSessions[userID] = session
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error {
+ delete(s.setupSessions, userID)
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) {
+ if s == nil || s.loginSessions == nil {
+ return nil, nil
+ }
+ return s.loginSessions[tempToken], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error {
+ if s.loginSessions == nil {
+ s.loginSessions = map[string]*service.TotpLoginSession{}
+ }
+ s.loginSessions[tempToken] = session
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error {
+ delete(s.loginSessions, tempToken)
+ return nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) {
+ if s.verifyAttempts == nil {
+ s.verifyAttempts = map[int64]int{}
+ }
+ s.verifyAttempts[userID]++
+ return s.verifyAttempts[userID], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) {
+ if s == nil || s.verifyAttempts == nil {
+ return 0, nil
+ }
+ return s.verifyAttempts[userID], nil
+}
+
+func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error {
+ delete(s.verifyAttempts, userID)
+ return nil
+}
+
+type oauthPendingFlowTotpEncryptorStub struct{}
+
+func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) {
+ return plaintext, nil
+}
+
+func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) {
+ return ciphertext, nil
+}
diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go
new file mode 100644
index 00000000..47bad942
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_test_helpers_test.go
@@ -0,0 +1,57 @@
+package handler
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func buildEncodedOAuthBindUserCookie(t *testing.T, userID int64, secret string) string {
+ t.Helper()
+ value, err := buildOAuthBindUserCookieValue(userID, secret)
+ require.NoError(t, err)
+ return value
+}
+
+func encodedCookie(name, value string) *http.Cookie {
+ return &http.Cookie{
+ Name: name,
+ Value: encodeCookieValue(value),
+ Path: "/",
+ }
+}
+
+func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
+ for _, cookie := range cookies {
+ if cookie.Name == name {
+ return cookie
+ }
+ }
+ return nil
+}
+
+func decodeCookieValueForTest(t *testing.T, value string) string {
+ t.Helper()
+ decoded, err := decodeCookieValue(value)
+ require.NoError(t, err)
+ return decoded
+}
+
+func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
+ t.Helper()
+ require.NotEmpty(t, location)
+
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+
+ rawValues := parsed.RawQuery
+ if rawValues == "" {
+ rawValues = parsed.Fragment
+ }
+ values, err := url.ParseQuery(rawValues)
+ require.NoError(t, err)
+ require.Equal(t, errorCode, values.Get("error"))
+ require.Equal(t, errorMessage, values.Get("error_message"))
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 9d24df88..4264002d 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -19,6 +19,7 @@ import (
"strings"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
@@ -32,14 +33,16 @@ import (
)
const (
- oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc"
- oidcOAuthStateCookieName = "oidc_oauth_state"
- oidcOAuthVerifierCookie = "oidc_oauth_verifier"
- oidcOAuthRedirectCookie = "oidc_oauth_redirect"
- oidcOAuthNonceCookie = "oidc_oauth_nonce"
- oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
- oidcOAuthDefaultRedirectTo = "/dashboard"
- oidcOAuthDefaultFrontendCB = "/auth/oidc/callback"
+ oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc"
+ oidcOAuthStateCookieName = "oidc_oauth_state"
+ oidcOAuthVerifierCookie = "oidc_oauth_verifier"
+ oidcOAuthRedirectCookie = "oidc_oauth_redirect"
+ oidcOAuthNonceCookie = "oidc_oauth_nonce"
+ oidcOAuthIntentCookieName = "oidc_oauth_intent"
+ oidcOAuthBindUserCookieName = "oidc_oauth_bind_user"
+ oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
+ oidcOAuthDefaultRedirectTo = "/dashboard"
+ oidcOAuthDefaultFrontendCB = "/auth/oidc/callback"
)
type oidcTokenResponse struct {
@@ -87,6 +90,8 @@ type oidcUserInfoClaims struct {
Username string
Subject string
EmailVerified *bool
+ DisplayName string
+ AvatarURL string
}
type oidcJWKSet struct {
@@ -127,9 +132,29 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
+ intent := normalizeOAuthIntent(c.Query("intent"))
+ oidcSetCookie(c, oidcOAuthIntentCookieName, encodeCookieValue(intent), oidcOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ oidcSetCookie(c, oidcOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), oidcOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+ }
codeChallenge := ""
if cfg.UsePKCE {
@@ -199,6 +224,8 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
+ oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
+ oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName)
@@ -212,6 +239,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = oidcOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+ intent, _ := readCookieDecoded(c, oidcOAuthIntentCookieName)
+ intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
@@ -258,16 +292,19 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
- if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" {
- redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
- return
- }
+ var idClaims *oidcIDTokenClaims
+ if cfg.ValidateIDToken {
+ if strings.TrimSpace(tokenResp.IDToken) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
+ return
+ }
- idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
- if err != nil {
- log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
- redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
- return
+ idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
+ if err != nil {
+ log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
+ redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
+ return
+ }
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
@@ -277,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
- subject := strings.TrimSpace(idClaims.Subject)
+ subject := ""
+ if idClaims != nil {
+ subject = strings.TrimSpace(idClaims.Subject)
+ }
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
@@ -285,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
- issuer := strings.TrimSpace(idClaims.Issuer)
+ issuer := ""
+ if idClaims != nil {
+ issuer = strings.TrimSpace(idClaims.Issuer)
+ }
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
@@ -295,9 +338,115 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
emailVerified := userInfoClaims.EmailVerified
- if emailVerified == nil {
+ if emailVerified == nil && idClaims != nil {
emailVerified = idClaims.EmailVerified
}
+ if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
+ redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
+ return
+ }
+
+ identityKey := oidcIdentityKey(issuer, subject)
+ compatEmail := strings.TrimSpace(userInfoClaims.Email)
+ if compatEmail == "" && idClaims != nil {
+ compatEmail = strings.TrimSpace(idClaims.Email)
+ }
+ email := oidcSyntheticEmailFromIdentityKey(identityKey)
+ username := firstNonEmpty(
+ userInfoClaims.Username,
+ func() string {
+ if idClaims != nil {
+ return idClaims.PreferredUsername
+ }
+ return ""
+ }(),
+ func() string {
+ if idClaims != nil {
+ return idClaims.Name
+ }
+ return ""
+ }(),
+ oidcFallbackUsername(subject),
+ )
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: issuer,
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "issuer": issuer,
+ "email_verified": emailVerified != nil && *emailVerified,
+ "provider_fallback": strings.TrimSpace(cfg.ProviderName),
+ "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
+ if idClaims != nil {
+ return idClaims.Name
+ }
+ return ""
+ }(), username),
+ "suggested_avatar_url": userInfoClaims.AvatarURL,
+ }
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
+ if intent == oauthIntentBindCurrentUser {
+ targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
+ return
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityRef,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityRef,
+ TargetUserID: &existingIdentityUser.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+
if cfg.RequireEmailVerified {
if emailVerified == nil || !*emailVerified {
redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
@@ -305,47 +454,137 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
}
- identityKey := oidcIdentityKey(issuer, subject)
- email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey)
- username := firstNonEmpty(
- userInfoClaims.Username,
- idClaims.PreferredUsername,
- idClaims.Name,
- oidcFallbackUsername(subject),
- )
-
- // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
- if err != nil {
- if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
- return
- }
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createOIDCOAuthChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ true,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
- redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ redirectToFrontendCallback(c, frontendCallback)
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ if err := h.createOIDCOAuthChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ h.isForceEmailOnThirdPartySignup(c.Request.Context()),
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
+func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ }
+ if forceEmailOnSignup && compatEmailUser == nil {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+ var targetUserID *int64
+ if compatEmailUser != nil && compatEmailUser.ID > 0 {
+ targetUserID = &compatEmailUser.ID
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ TargetUserID: targetUserID,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
}
type completeOIDCOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
@@ -358,17 +597,87 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
-
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+ if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -405,7 +714,7 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
@@ -560,9 +869,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC
if verified, ok := getGJSONBool(body, "email_verified"); ok {
claims.EmailVerified = &verified
}
+ claims.DisplayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "preferred_username"),
+ getGJSON(body, "username"),
+ )
+ claims.AvatarURL = firstNonEmpty(
+ getGJSON(body, "picture"),
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
claims.Email = strings.TrimSpace(claims.Email)
claims.Username = strings.TrimSpace(claims.Username)
claims.Subject = strings.TrimSpace(claims.Subject)
+ claims.DisplayName = strings.TrimSpace(claims.DisplayName)
+ claims.AvatarURL = strings.TrimSpace(claims.AvatarURL)
return claims
}
@@ -595,7 +921,7 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
@@ -831,14 +1157,6 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
}
-func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string {
- email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail))
- if email != "" {
- return email
- }
- return oidcSyntheticEmailFromIdentityKey(identityKey)
-}
-
func oidcFallbackUsername(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
index a161aa77..3216d51e 100644
--- a/backend/internal/handler/auth_oidc_oauth_test.go
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -1,6 +1,7 @@
package handler
import (
+ "bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -12,7 +13,15 @@ import (
"testing"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
@@ -30,26 +39,11 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
require.Contains(t, e1, "@oidc-connect.invalid")
}
-func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) {
- identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a")
-
- email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey)
- require.Equal(t, "user@example.com", email)
-
- email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey)
- require.Equal(t, "idtoken@example.com", email)
-
- email = oidcSelectLoginEmail("", "", identityKey)
- require.Contains(t, email, "@oidc-connect.invalid")
- require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email)
-}
-
func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
cfg := config.OIDCConnectConfig{
AuthorizeURL: "https://issuer.example.com/auth",
ClientID: "cid",
Scopes: "openid email profile",
- UsePKCE: true,
}
u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
@@ -106,6 +100,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) {
require.Error(t, err)
}
+func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) {
+ cfg := config.OIDCConnectConfig{}
+
+ claims := oidcParseUserInfo(`{
+ "sub":"subject-1",
+ "preferred_username":"alice",
+ "name":"Alice Example",
+ "picture":"https://cdn.example/avatar.png",
+ "email":"alice@example.com",
+ "email_verified":true
+ }`, cfg)
+
+ require.Equal(t, "subject-1", claims.Subject)
+ require.Equal(t, "alice", claims.Username)
+ require.Equal(t, "Alice Example", claims.DisplayName)
+ require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL)
+ require.NotNil(t, claims.EmailVerified)
+ require.True(t, *claims.EmailVerified)
+}
+
func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
@@ -118,3 +132,909 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
E: e,
}
}
+
+func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
+ handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
+ Enabled: true,
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/oauth/authorize",
+ TokenURL: "https://issuer.example.com/oauth/token",
+ UserInfoURL: "https://issuer.example.com/oauth/userinfo",
+ JWKSURL: "https://issuer.example.com/oauth/jwks",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ RequireEmailVerified: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ c.Request = req
+ c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 84})
+
+ handler.OIDCOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.Contains(t, location, "issuer.example.com/oauth/authorize")
+ require.Contains(t, location, "client_id=oidc-client")
+ require.Contains(t, location, "nonce=")
+
+ cookies := recorder.Result().Cookies()
+ require.NotNil(t, findCookie(cookies, oidcOAuthStateCookieName))
+ require.NotNil(t, findCookie(cookies, oidcOAuthRedirectCookie))
+ require.NotNil(t, findCookie(cookies, oidcOAuthVerifierCookie))
+ require.NotNil(t, findCookie(cookies, oidcOAuthNonceCookie))
+ require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))
+
+ intentCookie := findCookie(cookies, oidcOAuthIntentCookieName)
+ require.NotNil(t, intentCookie)
+ require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))
+
+ bindCookie := findCookie(cookies, oidcOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, int64(84), userID)
+}
+
+func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
+ handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
+ Enabled: true,
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/oauth/authorize",
+ TokenURL: "https://issuer.example.com/oauth/token",
+ UserInfoURL: "https://issuer.example.com/oauth/userinfo",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ ValidateIDToken: false,
+ RequireEmailVerified: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)
+
+ handler.OIDCOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotContains(t, location, "code_challenge=")
+ require.NotContains(t, location, "nonce=")
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
+}
+
+func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, r.ParseForm())
+ require.Empty(t, r.PostForm.Get("code_verifier"))
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
+ Enabled: true,
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ ValidateIDToken: false,
+ RequireEmailVerified: false,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+ require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+}
+
+func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-login",
+ PreferredUsername: "oidc_login",
+ DisplayName: "OIDC Login Display",
+ AvatarURL: "https://cdn.example/oidc-login.png",
+ Email: "oidc-login@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-subject-login"))).
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("oidc").
+ SetProviderKey(cfg.IssuerURL).
+ SetProviderSubject("oidc-subject-login").
+ SetMetadata(map[string]any{"username": "legacy-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-123"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-login"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, cfg.IssuerURL, session.ProviderKey)
+ require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+ require.Nil(t, completion["error"])
+}
+
+func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-disabled-subject",
+ PreferredUsername: "oidc_disabled",
+ DisplayName: "OIDC Disabled",
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("oidc").
+ SetProviderKey(cfg.IssuerURL).
+ SetProviderSubject("oidc-disabled-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-compat",
+ PreferredUsername: "oidc_compat",
+ DisplayName: "OIDC Compat Display",
+ AvatarURL: "https://cdn.example/oidc-compat.png",
+ Email: "legacy@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+ require.Equal(t, existingUser.Email, completion["existing_account_email"])
+ require.Equal(t, true, completion["existing_account_bindable"])
+ require.Equal(t, "compat_email_match", completion["choice_reason"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-unverified-compat",
+ PreferredUsername: "oidc_unverified",
+ DisplayName: "OIDC Unverified Compat Display",
+ AvatarURL: "https://cdn.example/oidc-unverified.png",
+ Email: "owner@example.com",
+ EmailVerified: false,
+ })
+ defer cleanup()
+ cfg.RequireEmailVerified = true
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ _, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback#error=email_not_verified&error_message=email+is+not+verified", recorder.Header().Get("Location"))
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestOIDCOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-invite",
+ PreferredUsername: "oidc_invite",
+ DisplayName: "OIDC Invite Display",
+ AvatarURL: "https://cdn.example/oidc-invite.png",
+ Email: "oidc-invite@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-456", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-456"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-456"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-invite"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Nil(t, session.TargetUserID)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-bind",
+ PreferredUsername: "oidc_bind",
+ DisplayName: "OIDC Bind Display",
+ AvatarURL: "https://cdn.example/oidc-bind.png",
+ Email: "oidc-bind@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-bind", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentBindCurrentUser))
+ req.AddCookie(encodedCookie(oidcOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, cfg.IssuerURL, session.ProviderKey)
+ require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Empty(t, completion["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, userCount)
+}
+
+func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-1").
+ SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Display",
+ "suggested_avatar_url": "https://cdn.example/oidc.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "OIDC Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("oidc-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-choice-subject-1").
+ SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-no-adoption").
+ SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Legacy",
+ "suggested_avatar_url": "https://cdn.example/oidc-legacy.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "oidc_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-conflict-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-conflict-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-conflict-subject").
+ SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-conflict-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
+
+ userCount, err := client.User.Query().
+ Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+type oidcProviderFixture struct {
+ Subject string
+ PreferredUsername string
+ DisplayName string
+ AvatarURL string
+ Email string
+ EmailVerified bool
+}
+
+func newOIDCOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) *AuthHandler {
+ t.Helper()
+ handler, _ := newOIDCOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
+ return handler
+}
+
+func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+ handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
+ handler.settingSvc = nil
+ handler.cfg = &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ OIDC: oauthCfg,
+ }
+ return handler, client
+}
+
+func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) {
+ t.Helper()
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ kid := "test-kid"
+ jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &privateKey.PublicKey)}}
+ tokenResponse := oidcTokenResponse{
+ AccessToken: "oidc-access-token",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ }
+
+ userInfoPayload := map[string]any{
+ "sub": fixture.Subject,
+ "preferred_username": fixture.PreferredUsername,
+ "name": fixture.DisplayName,
+ "picture": fixture.AvatarURL,
+ "email": fixture.Email,
+ "email_verified": fixture.EmailVerified,
+ }
+
+ var issuer string
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, json.NewEncoder(w).Encode(tokenResponse))
+ case "/userinfo":
+ require.NoError(t, json.NewEncoder(w).Encode(userInfoPayload))
+ case "/jwks":
+ require.NoError(t, json.NewEncoder(w).Encode(jwks))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+
+ issuer = server.URL
+ now := time.Now()
+ claims := oidcIDTokenClaims{
+ Email: fixture.Email,
+ EmailVerified: boolPtr(fixture.EmailVerified),
+ PreferredUsername: fixture.PreferredUsername,
+ Name: fixture.DisplayName,
+ Nonce: "nonce-" + fixture.Subject,
+ RegisteredClaims: jwt.RegisteredClaims{
+ Issuer: issuer,
+ Subject: fixture.Subject,
+ Audience: jwt.ClaimStrings{"oidc-client"},
+ IssuedAt: jwt.NewNumericDate(now),
+ NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)),
+ ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
+ },
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ token.Header["kid"] = kid
+ tokenResponse.IDToken, err = token.SignedString(privateKey)
+ require.NoError(t, err)
+
+ cfg := config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "Test OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: issuer,
+ AuthorizeURL: issuer + "/authorize",
+ TokenURL: issuer + "/token",
+ UserInfoURL: issuer + "/userinfo",
+ JWKSURL: issuer + "/jwks",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ RequireEmailVerified: false,
+ }
+ return cfg, server.Close
+}
diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go
new file mode 100644
index 00000000..f1c6d87d
--- /dev/null
+++ b/backend/internal/handler/auth_session_revocation_test.go
@@ -0,0 +1,61 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 29,
+ Email: "session@example.com",
+ Username: "session-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 7,
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
+ handler := &AuthHandler{authService: authService}
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/auth/revoke-all-sessions", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29})
+
+ handler.RevokeAllSessions(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(8), repo.user.TokenVersion)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Message string `json:"message"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
new file mode 100644
index 00000000..34e70ed0
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -0,0 +1,1350 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat"
+ wechatOAuthCookieMaxAgeSec = 10 * 60
+ wechatOAuthStateCookieName = "wechat_oauth_state"
+ wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
+ wechatOAuthIntentCookieName = "wechat_oauth_intent"
+ wechatOAuthModeCookieName = "wechat_oauth_mode"
+ wechatOAuthBindUserCookieName = "wechat_oauth_bind_user"
+ wechatOAuthDefaultRedirectTo = "/dashboard"
+ wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
+ wechatOAuthProviderKey = "wechat-main"
+ wechatOAuthLegacyProviderKey = "wechat"
+ wechatPaymentOAuthCookiePath = "/api/v1/auth/oauth/wechat/payment"
+ wechatPaymentOAuthStateName = "wechat_payment_oauth_state"
+ wechatPaymentOAuthRedirect = "wechat_payment_oauth_redirect"
+ wechatPaymentOAuthContextName = "wechat_payment_oauth_context"
+ wechatPaymentOAuthScope = "wechat_payment_oauth_scope"
+ wechatPaymentOAuthDefaultTo = "/purchase"
+ wechatPaymentOAuthFrontendCB = "/auth/wechat/payment/callback"
+
+ wechatOAuthIntentLogin = "login"
+ wechatOAuthIntentBind = "bind_current_user"
+ wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
+)
+
+var (
+ wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo"
+)
+
+type wechatOAuthConfig struct {
+ mode string
+ appID string
+ appSecret string
+ authorizeURL string
+ scope string
+ redirectURI string
+ frontendCallback string
+ openEnabled bool
+ mpEnabled bool
+}
+
+type wechatOAuthTokenResponse struct {
+ AccessToken string `json:"access_token"`
+ ExpiresIn int64 `json:"expires_in"`
+ RefreshToken string `json:"refresh_token"`
+ OpenID string `json:"openid"`
+ Scope string `json:"scope"`
+ UnionID string `json:"unionid"`
+ ErrCode int64 `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+type wechatOAuthUserInfoResponse struct {
+ OpenID string `json:"openid"`
+ Nickname string `json:"nickname"`
+ HeadImgURL string `json:"headimgurl"`
+ UnionID string `json:"unionid"`
+ ErrCode int64 `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+}
+
+type wechatPaymentOAuthContext struct {
+ PaymentType string `json:"payment_type"`
+ Amount string `json:"amount,omitempty"`
+ OrderType string `json:"order_type,omitempty"`
+ PlanID int64 `json:"plan_id,omitempty"`
+}
+
+// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
+// browser cookies required by the rebuild pending-auth bridge.
+func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ state, err := oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+ return
+ }
+
+ redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+ if redirectTo == "" {
+ redirectTo = wechatOAuthDefaultRedirectTo
+ }
+
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
+ intent := normalizeWeChatOAuthIntent(c.Query("intent"))
+ secureCookie := isRequestHTTPS(c)
+ wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ wechatSetCookie(c, wechatOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), wechatOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+ }
+
+ authURL, err := buildWeChatAuthorizeURL(cfg, state)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+ return
+ }
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
+// and stores the result in the unified pending-auth flow.
+func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
+ frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context())
+
+ if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+ return
+ }
+
+ code := strings.TrimSpace(c.Query("code"))
+ state := strings.TrimSpace(c.Query("state"))
+ if code == "" || state == "" {
+ redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ defer func() {
+ wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+ wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+ }()
+
+ expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
+ if err != nil || expectedState == "" || state != expectedState {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+ return
+ }
+
+ redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
+ redirectTo = sanitizeFrontendRedirectPath(redirectTo)
+ if redirectTo == "" {
+ redirectTo = wechatOAuthDefaultRedirectTo
+ }
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+
+ intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
+ mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
+ if err != nil || strings.TrimSpace(mode) == "" {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
+ return
+ }
+
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+
+ tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
+ return
+ }
+
+ unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
+ openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
+ providerSubject := unionid
+ if providerSubject == "" {
+ if cfg.requiresUnionID() {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
+ return
+ }
+ providerSubject = openid
+ }
+ if providerSubject == "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
+ return
+ }
+
+ username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
+ email := wechatSyntheticEmail(providerSubject)
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": providerSubject,
+ "openid": openid,
+ "unionid": unionid,
+ "mode": cfg.mode,
+ "channel": cfg.mode,
+ "channel_app_id": strings.TrimSpace(cfg.appID),
+ "channel_subject": openid,
+ "suggested_display_name": strings.TrimSpace(userInfo.Nickname),
+ "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
+ }
+ identityRef := service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ }
+
+ normalizedIntent := normalizeWeChatOAuthIntent(intent)
+ if normalizedIntent == wechatOAuthIntentBind {
+ if err := h.createWeChatBindPendingSession(c, cfg, providerSubject, openid, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+ switch infraerrors.Code(err) {
+ case http.StatusConflict:
+ redirectOAuthError(c, frontendCallback, "ownership_conflict", infraerrors.Reason(err), infraerrors.Message(err))
+ case http.StatusUnauthorized, http.StatusForbidden:
+ redirectOAuthError(c, frontendCallback, "auth_required", infraerrors.Reason(err), infraerrors.Message(err))
+ default:
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ }
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser == nil {
+ existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ }
+ if existingIdentityUser != nil {
+ if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+ if err := h.createWeChatChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ "",
+ nil,
+ true,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ if err := h.createWeChatChoicePendingSession(
+ c,
+ identityRef,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ "",
+ nil,
+ false,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+// WeChatPaymentOAuthStart starts the WeChat payment OAuth flow.
+// GET /api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay&redirect=/purchase
+func (h *AuthHandler) WeChatPaymentOAuthStart(c *gin.Context) {
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ paymentType := normalizeWeChatPaymentType(c.Query("payment_type"))
+ if paymentType == "" {
+ response.BadRequest(c, "Invalid payment type")
+ return
+ }
+
+ state, err := oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+ return
+ }
+
+ redirectTo := normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(c.Query("redirect")))
+ if redirectTo == "" {
+ redirectTo = wechatPaymentOAuthDefaultTo
+ }
+ rawContext, err := encodeWeChatPaymentOAuthContext(wechatPaymentOAuthContext{
+ PaymentType: paymentType,
+ Amount: strings.TrimSpace(c.Query("amount")),
+ OrderType: strings.TrimSpace(c.Query("order_type")),
+ PlanID: parseWeChatPaymentPlanID(c.Query("plan_id")),
+ })
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONTEXT_ENCODE_FAILED", "failed to encode oauth context").WithCause(err))
+ return
+ }
+
+ scope := normalizeWeChatPaymentScope(c.Query("scope"))
+ secureCookie := isRequestHTTPS(c)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthStateName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthRedirect, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthContextName, encodeCookieValue(rawContext), wechatOAuthCookieMaxAgeSec, secureCookie)
+ wechatPaymentSetCookie(c, wechatPaymentOAuthScope, encodeCookieValue(scope), wechatOAuthCookieMaxAgeSec, secureCookie)
+
+ cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c)
+ cfg.scope = scope
+ authURL, err := buildWeChatAuthorizeURL(cfg, state)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+ return
+ }
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatPaymentOAuthCallback exchanges a payment OAuth code for an OpenID and
+// forwards the browser back to the frontend callback route.
+func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
+ frontendCallback := wechatPaymentOAuthFrontendCB
+
+ if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+ return
+ }
+
+ code := strings.TrimSpace(c.Query("code"))
+ state := strings.TrimSpace(c.Query("state"))
+ if code == "" || state == "" {
+ redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ defer func() {
+ wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
+ wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
+ }()
+
+ expectedState, err := readCookieDecoded(c, wechatPaymentOAuthStateName)
+ if err != nil || expectedState == "" || state != expectedState {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+ return
+ }
+
+ redirectTo, _ := readCookieDecoded(c, wechatPaymentOAuthRedirect)
+ redirectTo = normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(redirectTo))
+ if redirectTo == "" {
+ redirectTo = wechatPaymentOAuthDefaultTo
+ }
+
+ rawContext, _ := readCookieDecoded(c, wechatPaymentOAuthContextName)
+ paymentContext, err := decodeWeChatPaymentOAuthContext(rawContext)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_context", "invalid oauth context", "")
+ return
+ }
+ if paymentContext.PaymentType == "" {
+ paymentContext.PaymentType = payment.TypeWxpay
+ }
+
+ scope, _ := readCookieDecoded(c, wechatPaymentOAuthScope)
+ scope = normalizeWeChatPaymentScope(scope)
+
+ cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c)
+ tokenResp, err := exchangeWeChatOAuthCode(c.Request.Context(), cfg, code)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", err.Error())
+ return
+ }
+
+ openid := strings.TrimSpace(tokenResp.OpenID)
+ if openid == "" {
+ redirectOAuthError(c, frontendCallback, "missing_openid", "missing openid", "")
+ return
+ }
+ if strings.TrimSpace(tokenResp.Scope) != "" {
+ scope = strings.TrimSpace(tokenResp.Scope)
+ }
+
+ resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{
+ OpenID: openid,
+ PaymentType: paymentContext.PaymentType,
+ Amount: paymentContext.Amount,
+ OrderType: paymentContext.OrderType,
+ PlanID: paymentContext.PlanID,
+ RedirectTo: redirectTo,
+ Scope: scope,
+ })
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "")
+ return
+ }
+
+ fragment := url.Values{}
+ fragment.Set("wechat_resume_token", resumeToken)
+ fragment.Set("redirect", redirectTo)
+ redirectWithFragment(c, frontendCallback, fragment)
+}
+
+func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
+ var legacyKey []byte
+ key, err := payment.ProvideEncryptionKey(h.cfg)
+ if err == nil {
+ legacyKey = []byte(key)
+ }
+ return service.NewLegacyAwarePaymentResumeService(legacyKey)
+}
+
+type completeWeChatOAuthRequest struct {
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
+// validating the invitation code and consuming the current pending browser session.
+// POST /api/v1/auth/oauth/wechat/complete-registration
+func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
+ var req completeWeChatOAuthRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+func (h *AuthHandler) createWeChatPendingSession(
+ c *gin.Context,
+ intent string,
+ providerSubject string,
+ email string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ tokenPair *service.TokenPair,
+ authErr error,
+ targetUserID *int64,
+) error {
+ completionResponse := map[string]any{
+ "redirect": redirectTo,
+ }
+ if authErr != nil {
+ if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
+ completionResponse["error"] = "invitation_required"
+ } else {
+ return authErr
+ }
+ } else if tokenPair != nil {
+ completionResponse["access_token"] = tokenPair.AccessToken
+ completionResponse["refresh_token"] = tokenPair.RefreshToken
+ completionResponse["expires_in"] = tokenPair.ExpiresIn
+ completionResponse["token_type"] = "Bearer"
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: intent,
+ Identity: service.PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: wechatOAuthProviderKey,
+ ProviderSubject: providerSubject,
+ },
+ TargetUserID: targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
+}
+
+func (h *AuthHandler) createWeChatChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ }
+ if forceEmailOnSignup {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
+}
+
+func (h *AuthHandler) createWeChatBindPendingSession(
+ c *gin.Context,
+ cfg wechatOAuthConfig,
+ providerSubject string,
+ channelSubject string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+) error {
+ currentUser, err := h.readOAuthBindTargetUser(c, wechatOAuthBindUserCookieName)
+ if err != nil {
+ return err
+ }
+ if err := h.ensureWeChatBindOwnership(c.Request.Context(), currentUser.ID, providerSubject, cfg, channelSubject); err != nil {
+ return err
+ }
+ return h.createWeChatPendingSession(
+ c,
+ wechatOAuthIntentBind,
+ providerSubject,
+ currentUser.Email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ nil,
+ nil,
+ ¤tUser.ID,
+ )
+}
+
+func (h *AuthHandler) readOAuthBindTargetUser(c *gin.Context, cookieName string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ userID, err := h.readOAuthBindUserIDFromCookie(c, cookieName)
+ if err != nil {
+ return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account")
+ }
+ userEntity, err := client.User.Get(c.Request.Context(), userID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account")
+ }
+ return nil, infraerrors.InternalServer("WECHAT_BIND_USER_LOOKUP_FAILED", "failed to load current user").WithCause(err)
+ }
+ return userEntity, nil
+}
+
+func (h *AuthHandler) ensureWeChatBindOwnership(
+ ctx context.Context,
+ userID int64,
+ providerSubject string,
+ cfg wechatOAuthConfig,
+ channelSubject string,
+) error {
+ client := h.entClient()
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ identities, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)),
+ ).
+ All(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err)
+ }
+ for _, identity := range identities {
+ if identity != nil && identity.UserID != userID {
+ activeOwner, lookupErr := findActiveUserByID(ctx, client, identity.UserID)
+ if lookupErr != nil {
+ return lookupErr
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ }
+
+ channelSubject = strings.TrimSpace(channelSubject)
+ channelAppID := strings.TrimSpace(cfg.appID)
+ if channelSubject == "" || channelAppID == "" {
+ return nil
+ }
+
+ channels, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err)
+ }
+ for _, channel := range channels {
+ if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
+ activeOwner, lookupErr := findActiveUserByID(ctx, client, channel.Edges.Identity.UserID)
+ if lookupErr != nil {
+ return lookupErr
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ }
+ return nil
+}
+
+func (h *AuthHandler) findWeChatUserByLegacyOpenID(
+ ctx context.Context,
+ identity service.PendingAuthIdentityKey,
+ cfg wechatOAuthConfig,
+ openid string,
+) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ providerType := strings.TrimSpace(identity.ProviderType)
+ providerSubject := strings.TrimSpace(identity.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey)
+ if providerSubject != "" {
+ records, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ WithUser().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
+ if err != nil || user == nil {
+ return user, err
+ }
+ return findActiveUserByID(ctx, client, user.ID)
+ }
+ }
+
+ openid = strings.TrimSpace(openid)
+ channel := strings.TrimSpace(cfg.mode)
+ channelAppID := strings.TrimSpace(cfg.appID)
+ if openid != "" && channel != "" && channelAppID != "" {
+ records, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(openid),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
+ if err != nil || user == nil {
+ return user, err
+ }
+ return findActiveUserByID(ctx, client, user.ID)
+ }
+ }
+
+ if openid == "" {
+ return nil, nil
+ }
+
+ records, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(openid),
+ ).
+ WithUser().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ user, err := singleWeChatIdentityUser(records)
+ if err != nil || user == nil {
+ return user, err
+ }
+ return findActiveUserByID(ctx, client, user.ID)
+}
+
+func wechatCompatibleProviderKeys(providerKey string) []string {
+ preferred := strings.TrimSpace(providerKey)
+ if preferred == "" {
+ preferred = wechatOAuthProviderKey
+ }
+ keys := []string{preferred}
+ if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) {
+ keys = append(keys, wechatOAuthLegacyProviderKey)
+ }
+ return keys
+}
+
+func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) {
+ var resolved *dbent.User
+ for _, record := range records {
+ if record == nil || record.Edges.User == nil {
+ continue
+ }
+ if resolved == nil {
+ resolved = record.Edges.User
+ continue
+ }
+ if resolved.ID != record.Edges.User.ID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ return resolved, nil
+}
+
+func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) {
+ var resolved *dbent.User
+ for _, record := range records {
+ if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil {
+ continue
+ }
+ if resolved == nil {
+ resolved = record.Edges.Identity.Edges.User
+ continue
+ }
+ if resolved.ID != record.Edges.Identity.Edges.User.ID {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ return resolved, nil
+}
+
+func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding(
+ ctx context.Context,
+ userID int64,
+ identity service.PendingAuthIdentityKey,
+ upstreamClaims map[string]any,
+) error {
+ client := h.entClient()
+ if client == nil {
+ return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ tx, err := client.Tx(ctx)
+ if err != nil {
+ return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims),
+ }, userID)
+ if err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
+ mode, err := resolveWeChatOAuthMode(rawMode, c)
+ if err != nil {
+ return wechatOAuthConfig{}, err
+ }
+
+ if h == nil || h.settingSvc == nil {
+ return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready")
+ }
+
+ apiBaseURL := ""
+ if h != nil && h.settingSvc != nil {
+ settings, err := h.settingSvc.GetAllSettings(ctx)
+ if err == nil && settings != nil {
+ apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+ }
+ }
+
+ effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+ if err != nil {
+ return wechatOAuthConfig{}, err
+ }
+ if !effective.SupportsMode(mode) {
+ return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+
+ cfg := wechatOAuthConfig{
+ mode: mode,
+ appID: strings.TrimSpace(effective.AppIDForMode(mode)),
+ appSecret: strings.TrimSpace(effective.AppSecretForMode(mode)),
+ redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")),
+ frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB),
+ scope: effective.ScopeForMode(mode),
+ openEnabled: effective.OpenEnabled,
+ mpEnabled: effective.MPEnabled,
+ }
+
+ switch mode {
+ case "mp":
+ cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
+ default:
+ cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
+ }
+ if strings.TrimSpace(cfg.redirectURI) == "" {
+ return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
+ }
+
+ return cfg, nil
+}
+
+func (cfg wechatOAuthConfig) requiresUnionID() bool {
+ return cfg.openEnabled && cfg.mpEnabled
+}
+
+func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string {
+ if h != nil && h.settingSvc != nil {
+ cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+ if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" {
+ return strings.TrimSpace(cfg.FrontendRedirectURL)
+ }
+ }
+ return wechatOAuthDefaultFrontendCB
+}
+
+func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
+ mode := strings.ToLower(strings.TrimSpace(rawMode))
+ if mode == "" {
+ if isWeChatBrowserRequest(c) {
+ return "mp", nil
+ }
+ return "open", nil
+ }
+ if mode != "open" && mode != "mp" {
+ return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
+ }
+ return mode, nil
+}
+
+func isWeChatBrowserRequest(c *gin.Context) bool {
+ if c == nil || c.Request == nil {
+ return false
+ }
+ return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger")
+}
+
+func normalizeWeChatOAuthIntent(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "", "login":
+ return wechatOAuthIntentLogin
+ case "bind", "bind_current_user":
+ return wechatOAuthIntentBind
+ case "adopt", "adopt_existing_user_by_email":
+ return wechatOAuthIntentAdoptEmail
+ default:
+ return wechatOAuthIntentLogin
+ }
+}
+
+func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
+ u, err := url.Parse(cfg.authorizeURL)
+ if err != nil {
+ return "", fmt.Errorf("parse authorize url: %w", err)
+ }
+ query := u.Query()
+ query.Set("appid", cfg.appID)
+ query.Set("redirect_uri", cfg.redirectURI)
+ query.Set("response_type", "code")
+ query.Set("scope", cfg.scope)
+ query.Set("state", state)
+ u.RawQuery = query.Encode()
+ u.Fragment = "wechat_redirect"
+ return u.String(), nil
+}
+
+func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
+ callbackPath = strings.TrimSpace(callbackPath)
+ if callbackPath == "" {
+ return ""
+ }
+
+ if raw := strings.TrimSpace(apiBaseURL); raw != "" {
+ if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+ basePath := strings.TrimRight(parsed.EscapedPath(), "/")
+ targetPath := callbackPath
+ if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
+ targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
+ } else if basePath != "" {
+ targetPath = basePath + callbackPath
+ }
+ return parsed.Scheme + "://" + parsed.Host + targetPath
+ }
+ }
+
+ if c == nil || c.Request == nil {
+ return ""
+ }
+ scheme := "http"
+ if isRequestHTTPS(c) {
+ scheme = "https"
+ }
+ host := strings.TrimSpace(c.Request.Host)
+ if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
+ host = forwardedHost
+ }
+ if host == "" {
+ return ""
+ }
+ return scheme + "://" + host + callbackPath
+}
+
+func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
+ tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code)
+ if err != nil {
+ return nil, nil, err
+ }
+ userInfo, err := fetchWeChatUserInfo(ctx, tokenResp)
+ if err != nil {
+ return nil, nil, err
+ }
+ return tokenResp, userInfo, nil
+}
+
+func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
+ endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
+ if err != nil {
+ return nil, fmt.Errorf("parse wechat access token url: %w", err)
+ }
+
+ query := endpoint.Query()
+ query.Set("appid", cfg.appID)
+ query.Set("secret", cfg.appSecret)
+ query.Set("code", strings.TrimSpace(code))
+ query.Set("grant_type", "authorization_code")
+ endpoint.RawQuery = query.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("build wechat access token request: %w", err)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request wechat access token: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read wechat access token response: %w", err)
+ }
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
+ }
+
+ var tokenResp wechatOAuthTokenResponse
+ if err := json.Unmarshal(body, &tokenResp); err != nil {
+ return nil, fmt.Errorf("decode wechat access token response: %w", err)
+ }
+ if tokenResp.ErrCode != 0 {
+ return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
+ }
+ if strings.TrimSpace(tokenResp.AccessToken) == "" {
+ return nil, fmt.Errorf("wechat access token missing access_token")
+ }
+ return &tokenResp, nil
+}
+
+func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
+ if tokenResp == nil {
+ return nil, fmt.Errorf("wechat token response is nil")
+ }
+
+ endpoint, err := url.Parse(wechatOAuthUserInfoURL)
+ if err != nil {
+ return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
+ }
+ query := endpoint.Query()
+ query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
+ query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
+ query.Set("lang", "zh_CN")
+ endpoint.RawQuery = query.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("build wechat userinfo request: %w", err)
+ }
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request wechat userinfo: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read wechat userinfo response: %w", err)
+ }
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
+ }
+
+ var userInfo wechatOAuthUserInfoResponse
+ if err := json.Unmarshal(body, &userInfo); err != nil {
+ return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
+ }
+ if userInfo.ErrCode != 0 {
+ return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
+ }
+ return &userInfo, nil
+}
+
+func wechatSyntheticEmail(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return ""
+ }
+ return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
+}
+
+func wechatFallbackUsername(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return "wechat_user"
+ }
+ return "wechat_" + truncateFragmentValue(subject)
+}
+
+func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: wechatOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func wechatClearCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: wechatOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func normalizeWeChatPaymentType(raw string) string {
+ switch strings.TrimSpace(raw) {
+ case payment.TypeWxpay, payment.TypeWxpayDirect:
+ return strings.TrimSpace(raw)
+ default:
+ return ""
+ }
+}
+
+func normalizeWeChatPaymentScope(raw string) string {
+ for _, part := range strings.FieldsFunc(strings.TrimSpace(raw), func(r rune) bool {
+ return r == ',' || r == ' ' || r == '\t' || r == '\n' || r == '\r'
+ }) {
+ switch strings.TrimSpace(part) {
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ case "snsapi_base":
+ return "snsapi_base"
+ }
+ }
+ return "snsapi_base"
+}
+
+func normalizeWeChatPaymentRedirectPath(path string) string {
+ path = strings.TrimSpace(path)
+ if path == "" {
+ return wechatPaymentOAuthDefaultTo
+ }
+ if path == "/payment" {
+ return "/purchase"
+ }
+ if strings.HasPrefix(path, "/payment?") {
+ return "/purchase" + strings.TrimPrefix(path, "/payment")
+ }
+ return path
+}
+
+func (h *AuthHandler) resolveWeChatPaymentOAuthCallbackURL(ctx context.Context, c *gin.Context) string {
+ apiBaseURL := ""
+ if h != nil && h.settingSvc != nil {
+ if settings, err := h.settingSvc.GetAllSettings(ctx); err == nil && settings != nil {
+ apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+ }
+ }
+ return resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/payment/callback")
+}
+
+func encodeWeChatPaymentOAuthContext(ctx wechatPaymentOAuthContext) (string, error) {
+ data, err := json.Marshal(ctx)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func decodeWeChatPaymentOAuthContext(raw string) (wechatPaymentOAuthContext, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return wechatPaymentOAuthContext{}, nil
+ }
+ var ctx wechatPaymentOAuthContext
+ if err := json.Unmarshal([]byte(raw), &ctx); err != nil {
+ return wechatPaymentOAuthContext{}, err
+ }
+ return ctx, nil
+}
+
+func parseWeChatPaymentPlanID(raw string) int64 {
+ id, _ := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
+ return id
+}
+
+func wechatPaymentSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: wechatPaymentOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func wechatPaymentClearCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: wechatPaymentOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
new file mode 100644
index 00000000..b3c7786d
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -0,0 +1,1498 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-open-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-open-secret",
+ service.SettingKeyWeChatConnectMode: "open",
+ service.SettingKeyWeChatConnectScopes: "snsapi_login",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+ defer client.Close()
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+ c.Request.Host = "api.example.com"
+
+ handler.WeChatOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotEmpty(t, location)
+ require.Contains(t, location, "open.weixin.qq.com")
+ require.Contains(t, location, "appid=wx-open-app")
+ require.Contains(t, location, "scope=snsapi_login")
+
+ cookies := recorder.Result().Cookies()
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
+ require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
+ require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
+}
+
+func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-shared-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-shared-secret",
+ service.SettingKeyWeChatConnectMode: "mp",
+ service.SettingKeyWeChatConnectScopes: "snsapi_base",
+ service.SettingKeyWeChatConnectOpenEnabled: "true",
+ service.SettingKeyWeChatConnectMPEnabled: "true",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+ c.Request.Host = "api.example.com"
+
+ handler.WeChatOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.NotEmpty(t, location)
+ require.Contains(t, location, "open.weixin.qq.com")
+ require.Contains(t, location, "connect/qrconnect")
+ require.Contains(t, location, "scope=snsapi_login")
+}
+
+func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.Equal(t, "wechat-main", session.ProviderKey)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
+ require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
+}
+
+func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
+ defer client.Close()
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Equal(t, "openid-123", session.ProviderSubject)
+ require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail)
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(wechatSyntheticEmail("union-456")).
+ SetUsername("wechat-existing-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"username": "wechat-existing-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+
+ completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+}
+
+func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(wechatSyntheticEmail("union-disabled")).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-disabled").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
+ defer client.Close()
+ handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+ handler.cfg.Totp.EncryptionKeyConfigured = true
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+ require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
+ require.NotEmpty(t, fragment.Get("wechat_resume_token"))
+ require.Empty(t, fragment.Get("openid"))
+ require.Empty(t, fragment.Get("payment_type"))
+ require.Empty(t, fragment.Get("amount"))
+ require.Empty(t, fragment.Get("order_type"))
+ require.Empty(t, fragment.Get("plan_id"))
+
+ claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
+ require.NoError(t, err)
+ require.Equal(t, "openid-123", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "12.5", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 7, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+}
+
+func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
+ defer client.Close()
+
+ legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+ explicitSigningKey := "explicit-payment-resume-signing-key"
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey)
+ handler.cfg.Totp.EncryptionKey = legacyKeyHex
+ handler.cfg.Totp.EncryptionKeyConfigured = true
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+
+ token := fragment.Get("wechat_resume_token")
+ require.NotEmpty(t, token)
+
+ claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token)
+ require.NoError(t, err)
+ require.Equal(t, "openid-mixed-key", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "18.8", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 9, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+
+ _, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token)
+ require.Error(t, err)
+}
+
+func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
+ testCases := []struct {
+ name string
+ mode string
+ appID string
+ appSecret string
+ openID string
+ }{
+ {
+ name: "open",
+ mode: "open",
+ appID: "wx-open-app",
+ appSecret: "wx-open-secret",
+ openID: "openid-open-123",
+ },
+ {
+ name: "mp",
+ mode: "mp",
+ appID: "wx-mp-app",
+ appSecret: "wx-mp-secret",
+ openID: "openid-mp-123",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback"))
+ defer client.Close()
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(context.Background())
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, wechatOAuthIntentBind, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, currentUser.Email, session.ResolvedEmail)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"])
+ require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"])
+ require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"])
+ require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"])
+
+ completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completionResponse["redirect"])
+ _, hasAccessToken := completionResponse["access_token"]
+ require.False(t, hasAccessToken)
+ })
+ }
+}
+
+func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ ownerIdentity, err := client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-owner").
+ SetMetadata(map[string]any{"unionid": "union-owner"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(ownerIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetChannel("open").
+ SetChannelAppID("wx-open-app").
+ SetChannelSubject("openid-123").
+ SetMetadata(map[string]any{"openid": "openid-123"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, true)
+ defer client.Close()
+
+ ctx := context.Background()
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
+ Code: "invite-1",
+ Type: service.RedeemTypeInvitation,
+ Status: service.StatusUnused,
+ }))
+
+ callbackRecorder := httptest.NewRecorder()
+ callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
+ callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ callbackReq.Host = "api.example.com"
+ callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ callbackCtx.Request = callbackReq
+
+ handler.WeChatOAuthCallback(callbackCtx)
+
+ require.Equal(t, http.StatusFound, callbackRecorder.Code)
+ require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+ sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
+
+ pendingSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"])
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
+ completeRecorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(completeRecorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, completeRecorder.Code)
+ responseData := decodeJSONBody(t, completeRecorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["adoption_required"])
+ require.Empty(t, responseData["access_token"])
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, consumed.ConsumedAt)
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, channelCount)
+
+ decisionCount, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, decisionCount)
+}
+
+func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("wechat-subject-no-adoption").
+ SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid").
+ SetBrowserSessionKey("wechat-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ "suggested_display_name": "WeChat Legacy",
+ "suggested_avatar_url": "https://cdn.example/wechat-legacy.png",
+ "mode": "open",
+ "channel": "open",
+ "channel_app_id": "wx-open-app",
+ "channel_subject": "openid-legacy",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("openid-123").
+ SetMetadata(map[string]any{"openid": "openid-123"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ openIDIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("openid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, openIDIdentityCount)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
+func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-invalid-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("wechat-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("wechat-choice-subject-1").
+ SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
+ SetBrowserSessionKey("wechat-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ legacyIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, legacyIdentityCount)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
+func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil)
+}
+
+func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string {
+ return map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: appID,
+ service.SettingKeyWeChatConnectAppSecret: secret,
+ service.SettingKeyWeChatConnectMode: mode,
+ service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode),
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect,
+ }
+}
+
+func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ userRepo := &oauthPendingFlowUserRepo{client: client}
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ values := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ }
+ for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") {
+ values[key] = value
+ }
+ for key, value := range extraSettings {
+ values[key] = value
+ }
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg)
+
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &wechatOAuthRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ cfg: cfg,
+ }, client
+}
+
+type wechatOAuthSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type wechatOAuthRefreshTokenCacheStub struct{}
+
+func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go
new file mode 100644
index 00000000..8982b80d
--- /dev/null
+++ b/backend/internal/handler/available_channel_handler.go
@@ -0,0 +1,283 @@
+package handler
+
+import (
+ "sort"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AvailableChannelHandler 处理用户侧「可用渠道」查询。
+//
+// 用户侧接口委托 ChannelService.ListAvailable,并在返回前做三层过滤:
+// 1. 行过滤:只保留状态为 Active 且与当前用户可访问分组有交集的渠道;
+// 2. 分组过滤:渠道的 Groups 只保留用户可访问的那些;
+// 3. 平台过滤:渠道的 SupportedModels 只保留平台在用户可见 Groups 中出现过的模型,
+// 防止"渠道同时挂在 antigravity / anthropic 两个平台的分组上,用户只访问
+// antigravity,却看到 anthropic 模型"这类跨平台信息泄漏;
+// 4. 字段白名单:仅返回用户需要的字段(省略 BillingModelSource / RestrictModels
+// / 内部 ID / Status 等管理字段)。
+type AvailableChannelHandler struct {
+ channelService *service.ChannelService
+ apiKeyService *service.APIKeyService
+ settingService *service.SettingService
+}
+
+// NewAvailableChannelHandler 创建用户侧可用渠道 handler。
+func NewAvailableChannelHandler(
+ channelService *service.ChannelService,
+ apiKeyService *service.APIKeyService,
+ settingService *service.SettingService,
+) *AvailableChannelHandler {
+ return &AvailableChannelHandler{
+ channelService: channelService,
+ apiKeyService: apiKeyService,
+ settingService: settingService,
+ }
+}
+
+// featureEnabled 返回 available-channels 开关是否启用。默认关闭(opt-in)。
+func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool {
+ if h.settingService == nil {
+ return false
+ }
+ return h.settingService.GetAvailableChannelsRuntime(c.Request.Context()).Enabled
+}
+
+// userAvailableGroup 用户可见的分组概要(白名单字段)。
+//
+// 前端据此区分专属 vs 公开分组(IsExclusive)、订阅 vs 标准分组(SubscriptionType,
+// 订阅视觉加深),并用 RateMultiplier 作为默认倍率;用户专属倍率前端走
+// /groups/rates,和 API 密钥页面保持一致。
+type userAvailableGroup struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+ SubscriptionType string `json:"subscription_type"`
+ RateMultiplier float64 `json:"rate_multiplier"`
+ IsExclusive bool `json:"is_exclusive"`
+}
+
+// userSupportedModelPricing 用户可见的定价字段白名单。
+type userSupportedModelPricing struct {
+ BillingMode string `json:"billing_mode"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ ImageOutputPrice *float64 `json:"image_output_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+ Intervals []userPricingIntervalDTO `json:"intervals"`
+}
+
+// userPricingIntervalDTO 定价区间白名单(去掉内部 ID、SortOrder 等前端不渲染的字段)。
+type userPricingIntervalDTO struct {
+ MinTokens int `json:"min_tokens"`
+ MaxTokens *int `json:"max_tokens"`
+ TierLabel string `json:"tier_label,omitempty"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+}
+
+// userSupportedModel 用户可见的支持模型条目。
+type userSupportedModel struct {
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+ Pricing *userSupportedModelPricing `json:"pricing"`
+}
+
+// userChannelPlatformSection 单渠道内某个平台的子视图:用户可见的分组 + 该平台
+// 支持的模型。按 platform 聚合后让前端可以把渠道名作为 row-group 一次渲染,
+// 后面的平台行按 sections 顺序铺开。
+type userChannelPlatformSection struct {
+ Platform string `json:"platform"`
+ Groups []userAvailableGroup `json:"groups"`
+ SupportedModels []userSupportedModel `json:"supported_models"`
+}
+
+// userAvailableChannel 用户可见的渠道条目(白名单字段)。
+//
+// 每个渠道聚合为一条记录,内嵌 platforms 子数组:每个 section 对应一个平台,
+// 包含该平台的 groups 和 supported_models。
+type userAvailableChannel struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Platforms []userChannelPlatformSection `json:"platforms"`
+}
+
+// List 列出当前用户可见的「可用渠道」。
+// GET /api/v1/channels/available
+func (h *AvailableChannelHandler) List(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ // Feature 未启用时返回空数组(不暴露渠道信息)。检查放在认证之后,
+ // 保持与未开关前的 401 行为一致:未登录先 401,登录后再按开关决定。
+ if !h.featureEnabled(c) {
+ response.Success(c, []userAvailableChannel{})
+ return
+ }
+
+ userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ allowedGroupIDs := make(map[int64]struct{}, len(userGroups))
+ for i := range userGroups {
+ allowedGroupIDs[userGroups[i].ID] = struct{}{}
+ }
+
+ channels, err := h.channelService.ListAvailable(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]userAvailableChannel, 0, len(channels))
+ for _, ch := range channels {
+ if ch.Status != service.StatusActive {
+ continue
+ }
+ visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs)
+ if len(visibleGroups) == 0 {
+ continue
+ }
+ sections := buildPlatformSections(ch, visibleGroups)
+ if len(sections) == 0 {
+ continue
+ }
+ out = append(out, userAvailableChannel{
+ Name: ch.Name,
+ Description: ch.Description,
+ Platforms: sections,
+ })
+ }
+
+ response.Success(c, out)
+}
+
+// buildPlatformSections 把一个渠道按 visibleGroups 的平台集合拆成有序的 section 列表:
+// 每个 section 对应一个平台,只包含该平台的 groups 和 supported_models。
+// 输出按 platform 字母序稳定排序,便于前端等效比较与回归测试。
+func buildPlatformSections(
+ ch service.AvailableChannel,
+ visibleGroups []userAvailableGroup,
+) []userChannelPlatformSection {
+ groupsByPlatform := make(map[string][]userAvailableGroup, 4)
+ for _, g := range visibleGroups {
+ if g.Platform == "" {
+ continue
+ }
+ groupsByPlatform[g.Platform] = append(groupsByPlatform[g.Platform], g)
+ }
+ if len(groupsByPlatform) == 0 {
+ return nil
+ }
+
+ platforms := make([]string, 0, len(groupsByPlatform))
+ for p := range groupsByPlatform {
+ platforms = append(platforms, p)
+ }
+ sort.Strings(platforms)
+
+ sections := make([]userChannelPlatformSection, 0, len(platforms))
+ for _, platform := range platforms {
+ platformSet := map[string]struct{}{platform: {}}
+ sections = append(sections, userChannelPlatformSection{
+ Platform: platform,
+ Groups: groupsByPlatform[platform],
+ SupportedModels: toUserSupportedModels(ch.SupportedModels, platformSet),
+ })
+ }
+ return sections
+}
+
+// filterUserVisibleGroups 仅保留用户可访问的分组。
+func filterUserVisibleGroups(
+ groups []service.AvailableGroupRef,
+ allowed map[int64]struct{},
+) []userAvailableGroup {
+ visible := make([]userAvailableGroup, 0, len(groups))
+ for _, g := range groups {
+ if _, ok := allowed[g.ID]; !ok {
+ continue
+ }
+ visible = append(visible, userAvailableGroup{
+ ID: g.ID,
+ Name: g.Name,
+ Platform: g.Platform,
+ SubscriptionType: g.SubscriptionType,
+ RateMultiplier: g.RateMultiplier,
+ IsExclusive: g.IsExclusive,
+ })
+ }
+ return visible
+}
+
+// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。
+// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。
+// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。
+func toUserSupportedModels(
+ src []service.SupportedModel,
+ allowedPlatforms map[string]struct{},
+) []userSupportedModel {
+ out := make([]userSupportedModel, 0, len(src))
+ for i := range src {
+ m := src[i]
+ if allowedPlatforms != nil {
+ if _, ok := allowedPlatforms[m.Platform]; !ok {
+ continue
+ }
+ }
+ out = append(out, userSupportedModel{
+ Name: m.Name,
+ Platform: m.Platform,
+ Pricing: toUserPricing(m.Pricing),
+ })
+ }
+ return out
+}
+
+// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。
+func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing {
+ if p == nil {
+ return nil
+ }
+ intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals))
+ for _, iv := range p.Intervals {
+ intervals = append(intervals, userPricingIntervalDTO{
+ MinTokens: iv.MinTokens,
+ MaxTokens: iv.MaxTokens,
+ TierLabel: iv.TierLabel,
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ })
+ }
+ billingMode := string(p.BillingMode)
+ if billingMode == "" {
+ billingMode = string(service.BillingModeToken)
+ }
+ return &userSupportedModelPricing{
+ BillingMode: billingMode,
+ InputPrice: p.InputPrice,
+ OutputPrice: p.OutputPrice,
+ CacheWritePrice: p.CacheWritePrice,
+ CacheReadPrice: p.CacheReadPrice,
+ ImageOutputPrice: p.ImageOutputPrice,
+ PerRequestPrice: p.PerRequestPrice,
+ Intervals: intervals,
+ }
+}
diff --git a/backend/internal/handler/available_channel_handler_test.go b/backend/internal/handler/available_channel_handler_test.go
new file mode 100644
index 00000000..0a7ce6c4
--- /dev/null
+++ b/backend/internal/handler/available_channel_handler_test.go
@@ -0,0 +1,157 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserAvailableChannel_Unauthenticated401(t *testing.T) {
+ // 没有 AuthSubject 注入时,handler 应返回 401 且不触达 service 依赖。
+ gin.SetMode(gin.TestMode)
+ h := &AvailableChannelHandler{} // nil services — 401 路径不会调用它们
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil)
+
+ h.List(c)
+
+ require.Equal(t, http.StatusUnauthorized, w.Code)
+}
+
+func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
+ // 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。
+ groups := []service.AvailableGroupRef{
+ {ID: 1, Name: "g1", Platform: "anthropic"},
+ {ID: 2, Name: "g2", Platform: "anthropic"},
+ {ID: 3, Name: "g3", Platform: "openai"},
+ }
+ allowed := map[int64]struct{}{1: {}, 3: {}}
+
+ visible := filterUserVisibleGroups(groups, allowed)
+ require.Len(t, visible, 2)
+ ids := []int64{visible[0].ID, visible[1].ID}
+ require.ElementsMatch(t, []int64{1, 3}, ids)
+}
+
+func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) {
+ // 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。
+ src := []service.SupportedModel{
+ {Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil},
+ {Name: "gpt-4o", Platform: "openai", Pricing: nil},
+ }
+ allowed := map[string]struct{}{"anthropic": {}}
+ out := toUserSupportedModels(src, allowed)
+ require.Len(t, out, 1)
+ require.Equal(t, "claude-sonnet-4-6", out[0].Name)
+}
+
+func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) {
+ // 显式传 nil allowedPlatforms 表示不做过滤。
+ src := []service.SupportedModel{
+ {Name: "a", Platform: "anthropic"},
+ {Name: "b", Platform: "openai"},
+ }
+ require.Len(t, toUserSupportedModels(src, nil), 2)
+}
+
+func TestUserAvailableChannel_FieldWhitelist(t *testing.T) {
+ // 通过序列化 userAvailableChannel 结构体验证响应形状:
+ // 只有 name / description / platforms;不含管理端字段。
+ row := userAvailableChannel{
+ Name: "ch",
+ Description: "d",
+ Platforms: []userChannelPlatformSection{
+ {
+ Platform: "anthropic",
+ Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}},
+ SupportedModels: []userSupportedModel{},
+ },
+ },
+ }
+ raw, err := json.Marshal(row)
+ require.NoError(t, err)
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(raw, &decoded))
+
+ for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} {
+ _, exists := decoded[key]
+ require.Falsef(t, exists, "user DTO must not expose %q", key)
+ }
+ for _, key := range []string{"name", "description", "platforms"} {
+ _, exists := decoded[key]
+ require.Truef(t, exists, "user DTO must expose %q", key)
+ }
+
+ // 验证 section 的字段(platform / groups / supported_models)。
+ rawSection, err := json.Marshal(row.Platforms[0])
+ require.NoError(t, err)
+ var sectionDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawSection, §ionDecoded))
+ for _, key := range []string{"platform", "groups", "supported_models"} {
+ _, exists := sectionDecoded[key]
+ require.Truef(t, exists, "platform section must expose %q", key)
+ }
+
+ // Group DTO 暴露区分专属/公开、订阅类型、默认倍率所需的字段,
+ // 前端据此渲染 GroupBadge 并与 API 密钥页保持一致的视觉。
+ rawGroup, err := json.Marshal(row.Platforms[0].Groups[0])
+ require.NoError(t, err)
+ var groupDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawGroup, &groupDecoded))
+ for _, key := range []string{"id", "name", "platform", "subscription_type", "rate_multiplier", "is_exclusive"} {
+ _, exists := groupDecoded[key]
+ require.Truef(t, exists, "group DTO must expose %q", key)
+ }
+
+ // pricing interval 白名单:不应暴露 id / sort_order。
+ pricing := toUserPricing(&service.ChannelModelPricing{
+ BillingMode: service.BillingModeToken,
+ Intervals: []service.PricingInterval{
+ {ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3},
+ },
+ })
+ require.NotNil(t, pricing)
+ require.Len(t, pricing.Intervals, 1)
+ rawIv, err := json.Marshal(pricing.Intervals[0])
+ require.NoError(t, err)
+ var ivDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawIv, &ivDecoded))
+ for _, key := range []string{"id", "pricing_id", "sort_order"} {
+ _, exists := ivDecoded[key]
+ require.Falsef(t, exists, "user pricing interval must not expose %q", key)
+ }
+}
+
+func TestBuildPlatformSections_GroupsByPlatform(t *testing.T) {
+ // 一个渠道横跨 anthropic / openai / 空平台:应该生成 2 个 section,
+ // 按 platform 字母序排序,各自 groups 和 supported_models 只含同平台条目。
+ ch := service.AvailableChannel{
+ Name: "ch",
+ SupportedModels: []service.SupportedModel{
+ {Name: "claude-sonnet-4-6", Platform: "anthropic"},
+ {Name: "gpt-4o", Platform: "openai"},
+ },
+ }
+ visible := []userAvailableGroup{
+ {ID: 1, Name: "g-openai", Platform: "openai"},
+ {ID: 2, Name: "g-ant", Platform: "anthropic"},
+ {ID: 3, Name: "g-empty", Platform: ""},
+ }
+ sections := buildPlatformSections(ch, visible)
+ require.Len(t, sections, 2)
+ require.Equal(t, "anthropic", sections[0].Platform)
+ require.Equal(t, "openai", sections[1].Platform)
+ require.Len(t, sections[0].Groups, 1)
+ require.Equal(t, int64(2), sections[0].Groups[0].ID)
+ require.Len(t, sections[0].SupportedModels, 1)
+ require.Equal(t, "claude-sonnet-4-6", sections[0].SupportedModels[0].Name)
+}
diff --git a/backend/internal/handler/channel_monitor_user_handler.go b/backend/internal/handler/channel_monitor_user_handler.go
new file mode 100644
index 00000000..cc36b334
--- /dev/null
+++ b/backend/internal/handler/channel_monitor_user_handler.go
@@ -0,0 +1,176 @@
+package handler
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/admin"
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ChannelMonitorUserHandler 渠道监控用户只读 handler。
+type ChannelMonitorUserHandler struct {
+ monitorService *service.ChannelMonitorService
+ settingService *service.SettingService
+}
+
+// NewChannelMonitorUserHandler 创建 handler。
+// settingService 用于每次请求前读取功能开关;关闭时 List/GetStatus 直接返回空/404。
+func NewChannelMonitorUserHandler(
+ monitorService *service.ChannelMonitorService,
+ settingService *service.SettingService,
+) *ChannelMonitorUserHandler {
+ return &ChannelMonitorUserHandler{
+ monitorService: monitorService,
+ settingService: settingService,
+ }
+}
+
+// featureEnabled 返回当前渠道监控功能是否开启。
+// settingService 为 nil(测试场景)视为启用。
+func (h *ChannelMonitorUserHandler) featureEnabled(c *gin.Context) bool {
+ if h.settingService == nil {
+ return true
+ }
+ return h.settingService.GetChannelMonitorRuntime(c.Request.Context()).Enabled
+}
+
+// --- Response ---
+
+type channelMonitorUserListItem struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ GroupName string `json:"group_name"`
+ PrimaryModel string `json:"primary_model"`
+ PrimaryStatus string `json:"primary_status"`
+ PrimaryLatencyMs *int `json:"primary_latency_ms"`
+ PrimaryPingLatencyMs *int `json:"primary_ping_latency_ms"`
+ Availability7d float64 `json:"availability_7d"`
+ ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"`
+ Timeline []channelMonitorUserTimelinePoint `json:"timeline"`
+}
+
+// channelMonitorUserTimelinePoint 主模型最近一次检测的 timeline 点。
+// 仅用于用户视图 list 响应,admin 视图不使用。
+type channelMonitorUserTimelinePoint struct {
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+ PingLatencyMs *int `json:"ping_latency_ms"`
+ CheckedAt string `json:"checked_at"`
+}
+
+type channelMonitorUserDetailResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ GroupName string `json:"group_name"`
+ Models []channelMonitorUserModelStat `json:"models"`
+}
+
+type channelMonitorUserModelStat struct {
+ Model string `json:"model"`
+ LatestStatus string `json:"latest_status"`
+ LatestLatencyMs *int `json:"latest_latency_ms"`
+ Availability7d float64 `json:"availability_7d"`
+ Availability15d float64 `json:"availability_15d"`
+ Availability30d float64 `json:"availability_30d"`
+ AvgLatency7dMs *int `json:"avg_latency_7d_ms"`
+}
+
+func userMonitorViewToItem(v *service.UserMonitorView) channelMonitorUserListItem {
+ extras := make([]dto.ChannelMonitorExtraModelStatus, 0, len(v.ExtraModels))
+ for _, e := range v.ExtraModels {
+ extras = append(extras, dto.ChannelMonitorExtraModelStatus{
+ Model: e.Model,
+ Status: e.Status,
+ LatencyMs: e.LatencyMs,
+ })
+ }
+ timeline := make([]channelMonitorUserTimelinePoint, 0, len(v.Timeline))
+ for _, p := range v.Timeline {
+ timeline = append(timeline, channelMonitorUserTimelinePoint{
+ Status: p.Status,
+ LatencyMs: p.LatencyMs,
+ PingLatencyMs: p.PingLatencyMs,
+ CheckedAt: p.CheckedAt.UTC().Format(time.RFC3339),
+ })
+ }
+ return channelMonitorUserListItem{
+ ID: v.ID,
+ Name: v.Name,
+ Provider: v.Provider,
+ GroupName: v.GroupName,
+ PrimaryModel: v.PrimaryModel,
+ PrimaryStatus: v.PrimaryStatus,
+ PrimaryLatencyMs: v.PrimaryLatencyMs,
+ PrimaryPingLatencyMs: v.PrimaryPingLatencyMs,
+ Availability7d: v.Availability7d,
+ ExtraModels: extras,
+ Timeline: timeline,
+ }
+}
+
+func userMonitorDetailToResponse(d *service.UserMonitorDetail) *channelMonitorUserDetailResponse {
+ models := make([]channelMonitorUserModelStat, 0, len(d.Models))
+ for _, m := range d.Models {
+ models = append(models, channelMonitorUserModelStat{
+ Model: m.Model,
+ LatestStatus: m.LatestStatus,
+ LatestLatencyMs: m.LatestLatencyMs,
+ Availability7d: m.Availability7d,
+ Availability15d: m.Availability15d,
+ Availability30d: m.Availability30d,
+ AvgLatency7dMs: m.AvgLatency7dMs,
+ })
+ }
+ return &channelMonitorUserDetailResponse{
+ ID: d.ID,
+ Name: d.Name,
+ Provider: d.Provider,
+ GroupName: d.GroupName,
+ Models: models,
+ }
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/channel-monitors
+func (h *ChannelMonitorUserHandler) List(c *gin.Context) {
+ if !h.featureEnabled(c) {
+ response.Success(c, gin.H{"items": []channelMonitorUserListItem{}})
+ return
+ }
+ views, err := h.monitorService.ListUserView(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ items := make([]channelMonitorUserListItem, 0, len(views))
+ for _, v := range views {
+ items = append(items, userMonitorViewToItem(v))
+ }
+ response.Success(c, gin.H{"items": items})
+}
+
+// GetStatus GET /api/v1/channel-monitors/:id/status
+func (h *ChannelMonitorUserHandler) GetStatus(c *gin.Context) {
+ if !h.featureEnabled(c) {
+ response.ErrorFrom(c, service.ErrChannelMonitorNotFound)
+ return
+ }
+ // 复用 admin.ParseChannelMonitorID 保持错误码与日志一致。
+ id, ok := admin.ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ detail, err := h.monitorService.GetUserDetail(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, userMonitorDetailToResponse(detail))
+}
diff --git a/backend/internal/handler/dto/channel_monitor.go b/backend/internal/handler/dto/channel_monitor.go
new file mode 100644
index 00000000..3c0c5e11
--- /dev/null
+++ b/backend/internal/handler/dto/channel_monitor.go
@@ -0,0 +1,10 @@
+package dto
+
+// ChannelMonitorExtraModelStatus 渠道监控附加模型最近一次状态。
+// 同时被 admin handler(List 响应)与 user handler(List 响应)复用,
+// 字段必须保持一致以保证前端拿到统一结构。
+type ChannelMonitorExtraModelStatus struct {
+ Model string `json:"model"`
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index d2ccb8d6..f7503c2e 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -21,6 +21,7 @@ func UserFromServiceShallow(u *service.User) *User {
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
+ LastActiveAt: u.LastActiveAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
@@ -28,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User {
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
+ RPMLimit: u.RPMLimit,
}
}
@@ -66,6 +68,7 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return &AdminUser{
User: *base,
Notes: u.Notes,
+ LastUsedAt: u.LastUsedAt,
GroupRates: u.GroupRates,
}
}
@@ -182,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group {
AllowMessagesDispatch: g.AllowMessagesDispatch,
RequireOAuthOnly: g.RequireOAuthOnly,
RequirePrivacySet: g.RequirePrivacySet,
+ RPMLimit: g.RPMLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
diff --git a/backend/internal/handler/dto/public_settings_injection_schema_test.go b/backend/internal/handler/dto/public_settings_injection_schema_test.go
new file mode 100644
index 00000000..428fed3d
--- /dev/null
+++ b/backend/internal/handler/dto/public_settings_injection_schema_test.go
@@ -0,0 +1,70 @@
+package dto
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// TestPublicSettingsInjectionPayload_SchemaDoesNotDrift guarantees the SSR
+// injection struct exposes every JSON field consumed by the frontend.
+//
+// Why this test exists: before we extracted a named PublicSettingsInjectionPayload
+// type, the inline struct was manually kept in sync with dto.PublicSettings and
+// drifted — ChannelMonitorEnabled / AvailableChannelsEnabled were missing, which
+// made the frontend read `undefined` on refresh and hide the "可用渠道" menu
+// until the async /api/v1/settings/public round-trip finished.
+//
+// This test compares the two JSON-tag sets and fails if injection is missing
+// any field that dto.PublicSettings exposes. Adding a new feature flag with
+// only a DTO entry will fail this test until the injection struct is updated.
+//
+// Intentional exclusions (fields present on dto.PublicSettings that SSR does
+// not need to inject) are listed in `dtoOnlyFields` below with a reason.
+func TestPublicSettingsInjectionPayload_SchemaDoesNotDrift(t *testing.T) {
+ injection := jsonTags(reflect.TypeOf(service.PublicSettingsInjectionPayload{}))
+ dtoKeys := jsonTags(reflect.TypeOf(PublicSettings{}))
+
+ // Fields that legitimately live only on the DTO. Keep tiny; document each.
+ dtoOnlyFields := map[string]string{
+ // sora_client_enabled is an upstream-only field the fork does not surface.
+ "sora_client_enabled": "upstream-only field, not used on this fork",
+ // force_email_on_third_party_signup lives on the DTO but is not injected via SSR.
+ "force_email_on_third_party_signup": "auth-source default, not a feature flag",
+ }
+
+ var missing []string
+ for key := range dtoKeys {
+ if _, ok := injection[key]; ok {
+ continue
+ }
+ if _, allowed := dtoOnlyFields[key]; allowed {
+ continue
+ }
+ missing = append(missing, key)
+ }
+ if len(missing) > 0 {
+ t.Fatalf("service.PublicSettingsInjectionPayload is missing JSON fields present on dto.PublicSettings: %s\n"+
+ "add the field to PublicSettingsInjectionPayload (and GetPublicSettingsForInjection), or "+
+ "document the exclusion in dtoOnlyFields with a reason.", strings.Join(missing, ", "))
+ }
+}
+
+func jsonTags(t reflect.Type) map[string]struct{} {
+ out := make(map[string]struct{})
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ tag := f.Tag.Get("json")
+ if tag == "" || tag == "-" {
+ continue
+ }
+ name := strings.SplitN(tag, ",", 2)[0]
+ if name == "" {
+ continue
+ }
+ out[name] = struct{}{}
+ }
+ return out
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 3659e79b..492be170 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -51,6 +51,23 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
+ WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
+ WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"`
+ WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
+ WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"`
+ WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
+ WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"`
+ WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
+ WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
+ WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
OIDCConnectClientID string `json:"oidc_connect_client_id"`
@@ -89,9 +106,14 @@ type SystemSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
+ DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
+ DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -120,13 +142,23 @@ type SystemSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"`
// Gateway forwarding behavior
- EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
- EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
- EnableCCHSigning bool `json:"enable_cch_signing"`
+ EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
+ EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
+ EnableCCHSigning bool `json:"enable_cch_signing"`
+ EnableAnthropicCacheTTL1hInjection bool `json:"enable_anthropic_cache_ttl_1h_injection"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"`
+
// Payment configuration
PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"`
@@ -157,6 +189,19 @@ type SystemSettings struct {
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
+
+ // Channel Monitor feature switch
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+
+ // Available Channels feature switch (user-facing aggregate view)
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature switch
+ AffiliateEnabled bool `json:"affiliate_enabled"`
+
+ // OpenAI fast/flex policy
+ OpenAIFastPolicySettings *OpenAIFastPolicySettings `json:"openai_fast_policy_settings,omitempty"`
}
type DefaultSubscriptionSetting struct {
@@ -167,6 +212,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
@@ -189,6 +235,10 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
@@ -199,6 +249,13 @@ type PublicSettings struct {
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
+
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
@@ -241,6 +298,22 @@ type BetaPolicySettings struct {
Rules []BetaPolicyRule `json:"rules"`
}
+// OpenAIFastPolicyRule OpenAI fast/flex 策略规则 DTO
+type OpenAIFastPolicyRule struct {
+ ServiceTier string `json:"service_tier"`
+ Action string `json:"action"`
+ Scope string `json:"scope"`
+ ErrorMessage string `json:"error_message,omitempty"`
+ ModelWhitelist []string `json:"model_whitelist,omitempty"`
+ FallbackAction string `json:"fallback_action,omitempty"`
+ FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
+}
+
+// OpenAIFastPolicySettings OpenAI fast 策略配置 DTO
+type OpenAIFastPolicySettings struct {
+ Rules []OpenAIFastPolicyRule `json:"rules"`
+}
+
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
// Returns empty slice on empty/invalid input.
func ParseCustomMenuItems(raw string) []CustomMenuItem {
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 8c1e166f..5cc2f8e4 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -7,16 +7,17 @@ import (
)
type User struct {
- ID int64 `json:"id"`
- Email string `json:"email"`
- Username string `json:"username"`
- Role string `json:"role"`
- Balance float64 `json:"balance"`
- Concurrency int `json:"concurrency"`
- Status string `json:"status"`
- AllowedGroups []int64 `json:"allowed_groups"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
+ ID int64 `json:"id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ Role string `json:"role"`
+ Balance float64 `json:"balance"`
+ Concurrency int `json:"concurrency"`
+ Status string `json:"status"`
+ AllowedGroups []int64 `json:"allowed_groups"`
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
@@ -25,6 +26,9 @@ type User struct {
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
+ // RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
+ RPMLimit int `json:"rpm_limit"`
+
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
@@ -34,7 +38,8 @@ type User struct {
type AdminUser struct {
User
- Notes string `json:"notes"`
+ Notes string `json:"notes"`
+ LastUsedAt *time.Time `json:"last_used_at"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
@@ -106,6 +111,9 @@ type Group struct {
RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"`
+ // RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
+ RPMLimit int `json:"rpm_limit"`
+
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go
new file mode 100644
index 00000000..a17f0ce4
--- /dev/null
+++ b/backend/internal/handler/dto/user_mapper_activity_test.go
@@ -0,0 +1,33 @@
+package dto
+
+import (
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
+ t.Parallel()
+
+ lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(15 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(45 * time.Minute)
+
+ out := UserFromServiceAdmin(&service.User{
+ ID: 42,
+ Email: "admin@example.com",
+ Username: "admin",
+ Role: service.RoleAdmin,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ })
+
+ require.NotNil(t, out)
+ require.NotNil(t, out.LastActiveAt)
+ require.NotNil(t, out.LastUsedAt)
+ require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go
index a897bc40..db29618a 100644
--- a/backend/internal/handler/endpoint.go
+++ b/backend/internal/handler/endpoint.go
@@ -15,10 +15,12 @@ import (
// ──────────────────────────────────────────────────────────
const (
- EndpointMessages = "/v1/messages"
- EndpointChatCompletions = "/v1/chat/completions"
- EndpointResponses = "/v1/responses"
- EndpointGeminiModels = "/v1beta/models"
+ EndpointMessages = "/v1/messages"
+ EndpointChatCompletions = "/v1/chat/completions"
+ EndpointResponses = "/v1/responses"
+ EndpointImagesGenerations = "/v1/images/generations"
+ EndpointImagesEdits = "/v1/images/edits"
+ EndpointGeminiModels = "/v1beta/models"
)
// gin.Context keys used by the middleware and helpers below.
@@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string {
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
return EndpointMessages
+ case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"):
+ return EndpointImagesGenerations
+ case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"):
+ return EndpointImagesEdits
case strings.Contains(path, EndpointResponses):
return EndpointResponses
case strings.Contains(path, EndpointGeminiModels):
@@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
switch platform {
case service.PlatformOpenAI:
+ if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
+ return inbound
+ }
// OpenAI forwards everything to the Responses API.
// Preserve subresource suffix (e.g. /v1/responses/compact).
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go
index 1519bc9e..369c5fa7 100644
--- a/backend/internal/handler/endpoint_test.go
+++ b/backend/internal/handler/endpoint_test.go
@@ -25,12 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/responses", EndpointResponses},
+ {"/v1/images/generations", EndpointImagesGenerations},
+ {"/v1/images/edits", EndpointImagesEdits},
{"/v1beta/models", EndpointGeminiModels},
// Prefixed paths (antigravity, openai).
{"/antigravity/v1/messages", EndpointMessages},
{"/openai/v1/responses", EndpointResponses},
{"/openai/v1/responses/compact", EndpointResponses},
+ {"/openai/v1/images/generations", EndpointImagesGenerations},
+ {"/openai/v1/images/edits", EndpointImagesEdits},
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
// Gin route patterns with wildcards.
@@ -73,6 +77,8 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
+ {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
+ {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index f5eff8c9..7b082b07 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -259,6 +262,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
+ // [DEBUG-STICKY] 打印会话 hash 生成结果
+ reqLog.Info("sticky.session_hash_generated",
+ zap.String("session_hash", sessionHash),
+ zap.String("metadata_user_id_raw", parsedReq.MetadataUserID),
+ )
+
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform := ""
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
@@ -275,6 +284,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
+ // [DEBUG-STICKY] 打印粘性会话查询结果
+ reqLog.Info("sticky.cache_lookup",
+ zap.String("session_key", sessionKey),
+ zap.Int64("bound_account_id", sessionBoundAccountID),
+ )
if sessionBoundAccountID > 0 {
prefetchedGroupID := int64(0)
if apiKey.GroupID != nil {
@@ -283,6 +297,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
+ } else {
+ reqLog.Info("sticky.no_session_key", zap.String("session_hash", sessionHash))
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
@@ -301,6 +317,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ reqLog.Warn("gateway.select_account_no_available",
+ zap.String("model", reqModel),
+ zap.Int64p("group_id", apiKey.GroupID),
+ zap.String("platform", platform),
+ zap.Error(err),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
@@ -344,6 +366,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
+ zap.Int64("account_id", account.ID),
+ zap.String("model", reqModel),
+ zap.String("platform", platform),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -522,9 +549,22 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
+ reqLog.Info("sticky.selecting_account",
+ zap.String("session_key", sessionKey),
+ zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
+ zap.Bool("has_bound_session", hasBoundSession),
+ zap.Int("failed_account_count", len(fs.FailedAccountIDs)),
+ )
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ reqLog.Warn("gateway.select_account_no_available",
+ zap.String("model", reqModel),
+ zap.Int64p("group_id", currentAPIKey.GroupID),
+ zap.String("platform", platform),
+ zap.Bool("fallback_used", fallbackUsed),
+ zap.Error(err),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
@@ -548,6 +588,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
+ // [DEBUG-STICKY] 打印账号选择结果
+ reqLog.Info("sticky.account_selected",
+ zap.Int64("selected_account_id", account.ID),
+ zap.String("account_name", account.Name),
+ zap.Bool("slot_acquired", selection.Acquired),
+ zap.Bool("has_wait_plan", selection.WaitPlan != nil),
+ zap.Int64("sticky_bound_account_id", sessionBoundAccountID),
+ zap.Bool("sticky_honored", sessionBoundAccountID > 0 && sessionBoundAccountID == account.ID),
+ )
+
// 检查请求拦截(预热请求、SUGGESTION MODE等)
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
@@ -568,6 +618,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
+ zap.Int64("account_id", account.ID),
+ zap.String("model", reqModel),
+ zap.String("platform", platform),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -609,6 +664,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// Slot acquired: no longer waiting in queue.
releaseWait()
+ reqLog.Info("sticky.bind_after_wait",
+ zap.String("session_key", sessionKey),
+ zap.Int64("account_id", account.ID),
+ )
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
@@ -735,7 +794,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -800,6 +862,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
+ // 绑定粘性会话(成功转发后绑定/刷新)
+ // - 无现有绑定(首次请求):创建绑定
+ // - 选中账号与粘性账号一致:刷新 TTL
+ // - 粘性账号因负载/RPM 被跳过、选中了其他账号:不覆盖原绑定,
+ // 下次请求粘性账号恢复后仍可命中
+ if sessionKey != "" && (sessionBoundAccountID == 0 || sessionBoundAccountID == account.ID) {
+ if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
+ reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+ }
+
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -1441,7 +1514,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.errorResponse(c, status, code, message)
return
}
@@ -1684,25 +1760,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c.JSON(http.StatusOK, response)
}
-func billingErrorDetails(err error) (status int, code, message string) {
+func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
- return http.StatusServiceUnavailable, "billing_service_error", msg
+ return http.StatusServiceUnavailable, "billing_service_error", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
msg := pkgerrors.Message(err)
- return http.StatusTooManyRequests, "rate_limit_exceeded", msg
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
msg := pkgerrors.Message(err)
- return http.StatusTooManyRequests, "rate_limit_exceeded", msg
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
msg := pkgerrors.Message(err)
- return http.StatusTooManyRequests, "rate_limit_exceeded", msg
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
+ }
+ // 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
+ // 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
+ if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
+ msg := pkgerrors.Message(err)
+ retrySeconds := 60 - int(time.Now().Unix()%60)
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
}
msg := pkgerrors.Message(err)
if msg == "" {
@@ -1712,7 +1795,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
}
- return http.StatusForbidden, "billing_error", msg
+ return http.StatusForbidden, "billing_error", msg, 0
}
func (h *GatewayHandler) metadataBridgeEnabled() bool {
diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go
new file mode 100644
index 00000000..e8a88802
--- /dev/null
+++ b/backend/internal/handler/gateway_handler_billing_error_test.go
@@ -0,0 +1,54 @@
+package handler
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
+ status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
+ require.Equal(t, http.StatusTooManyRequests, status)
+ require.Equal(t, "rate_limit_exceeded", code)
+ require.NotEmpty(t, msg)
+ require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
+ require.LessOrEqual(t, retryAfter, 60)
+}
+
+func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
+ status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
+ require.Equal(t, http.StatusTooManyRequests, status)
+ require.Equal(t, "rate_limit_exceeded", code)
+ require.NotEmpty(t, msg)
+ require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
+ require.LessOrEqual(t, retryAfter, 60)
+}
+
+func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
+ // 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
+ for _, err := range []error{
+ service.ErrAPIKeyRateLimit5hExceeded,
+ service.ErrAPIKeyRateLimit1dExceeded,
+ service.ErrAPIKeyRateLimit7dExceeded,
+ } {
+ status, code, _, _ := billingErrorDetails(err)
+ require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
+ require.Equal(t, "rate_limit_exceeded", code)
+ }
+}
+
+func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
+ status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
+ require.Equal(t, http.StatusServiceUnavailable, status)
+ require.Equal(t, "billing_service_error", code)
+ require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
+}
+
+func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
+ status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
+ require.Equal(t, http.StatusForbidden, status)
+ require.Equal(t, "billing_error", code)
+ require.NotEmpty(t, msg)
+}
diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go
index be267332..4290e54b 100644
--- a/backend/internal/handler/gateway_handler_chat_completions.go
+++ b/backend/internal/handler/gateway_handler_chat_completions.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go
index e908eb9e..683cf2b7 100644
--- a/backend/internal/handler/gateway_handler_responses.go
+++ b/backend/internal/handler/gateway_handler_responses.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.responsesErrorResponse(c, status, code, message)
return
}
diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
index 1fdc46ba..57554cf9 100644
--- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
+++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
@@ -50,6 +50,9 @@ func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
return true, nil
}
+func (f *fakeSchedulerCache) UnlockBucket(_ context.Context, _ service.SchedulerBucket) error {
+ return nil
+}
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
@@ -173,7 +176,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
- billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
+ billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index d200c17c..2a34e3f0 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -9,6 +9,7 @@ import (
"errors"
"net/http"
"regexp"
+ "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
- status, _, message := billingErrorDetails(err)
+ status, _, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
googleError(c, status, message)
return
}
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index 906a74f1..13e3ac88 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -6,50 +6,55 @@ import (
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
- Dashboard *admin.DashboardHandler
- User *admin.UserHandler
- Group *admin.GroupHandler
- Account *admin.AccountHandler
- Announcement *admin.AnnouncementHandler
- DataManagement *admin.DataManagementHandler
- Backup *admin.BackupHandler
- OAuth *admin.OAuthHandler
- OpenAIOAuth *admin.OpenAIOAuthHandler
- GeminiOAuth *admin.GeminiOAuthHandler
- AntigravityOAuth *admin.AntigravityOAuthHandler
- Proxy *admin.ProxyHandler
- Redeem *admin.RedeemHandler
- Promo *admin.PromoHandler
- Setting *admin.SettingHandler
- Ops *admin.OpsHandler
- System *admin.SystemHandler
- Subscription *admin.SubscriptionHandler
- Usage *admin.UsageHandler
- UserAttribute *admin.UserAttributeHandler
- ErrorPassthrough *admin.ErrorPassthroughHandler
- TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
- APIKey *admin.AdminAPIKeyHandler
- ScheduledTest *admin.ScheduledTestHandler
- Channel *admin.ChannelHandler
- Payment *admin.PaymentHandler
+ Dashboard *admin.DashboardHandler
+ User *admin.UserHandler
+ Group *admin.GroupHandler
+ Account *admin.AccountHandler
+ Announcement *admin.AnnouncementHandler
+ DataManagement *admin.DataManagementHandler
+ Backup *admin.BackupHandler
+ OAuth *admin.OAuthHandler
+ OpenAIOAuth *admin.OpenAIOAuthHandler
+ GeminiOAuth *admin.GeminiOAuthHandler
+ AntigravityOAuth *admin.AntigravityOAuthHandler
+ Proxy *admin.ProxyHandler
+ Redeem *admin.RedeemHandler
+ Promo *admin.PromoHandler
+ Setting *admin.SettingHandler
+ Ops *admin.OpsHandler
+ System *admin.SystemHandler
+ Subscription *admin.SubscriptionHandler
+ Usage *admin.UsageHandler
+ UserAttribute *admin.UserAttributeHandler
+ ErrorPassthrough *admin.ErrorPassthroughHandler
+ TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
+ APIKey *admin.AdminAPIKeyHandler
+ ScheduledTest *admin.ScheduledTestHandler
+ Channel *admin.ChannelHandler
+ ChannelMonitor *admin.ChannelMonitorHandler
+ ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
+ Payment *admin.PaymentHandler
+ Affiliate *admin.AffiliateHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
- Auth *AuthHandler
- User *UserHandler
- APIKey *APIKeyHandler
- Usage *UsageHandler
- Redeem *RedeemHandler
- Subscription *SubscriptionHandler
- Announcement *AnnouncementHandler
- Admin *AdminHandlers
- Gateway *GatewayHandler
- OpenAIGateway *OpenAIGatewayHandler
- Setting *SettingHandler
- Totp *TotpHandler
- Payment *PaymentHandler
- PaymentWebhook *PaymentWebhookHandler
+ Auth *AuthHandler
+ User *UserHandler
+ APIKey *APIKeyHandler
+ Usage *UsageHandler
+ Redeem *RedeemHandler
+ Subscription *SubscriptionHandler
+ Announcement *AnnouncementHandler
+ ChannelMonitor *ChannelMonitorUserHandler
+ Admin *AdminHandlers
+ Gateway *GatewayHandler
+ OpenAIGateway *OpenAIGatewayHandler
+ Setting *SettingHandler
+ Totp *TotpHandler
+ Payment *PaymentHandler
+ PaymentWebhook *PaymentWebhookHandler
+ AvailableChannel *AvailableChannelHandler
}
// BuildInfo contains build-time information
diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index 991cbb91..f395970a 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -126,6 +130,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ false,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
@@ -149,6 +154,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ false,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
diff --git a/backend/internal/handler/openai_gateway_compact_log_test.go b/backend/internal/handler/openai_gateway_compact_log_test.go
index 062f318b..e18509b4 100644
--- a/backend/internal/handler/openai_gateway_compact_log_test.go
+++ b/backend/internal/handler/openai_gateway_compact_log_test.go
@@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
@@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
@@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 5319b55d..7676ffa3 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -187,6 +187,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "previous_response_id_requires_wsv2"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2")
+ return
}
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -223,13 +228,17 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
+ requireCompact := isOpenAIRemoteCompactPath(c)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
@@ -248,6 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ requireCompact,
)
if err != nil {
reqLog.Warn("openai.account_select_failed",
@@ -255,6 +265,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
+ if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
+ return
+ }
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@@ -589,7 +603,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -633,6 +650,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ false,
)
if err != nil {
reqLog.Warn("openai_messages.account_select_failed",
@@ -856,7 +874,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
if validation.HasItemReferenceForAllCallIDs {
@@ -866,7 +884,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
@@ -1156,6 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index d299fb81..8ecee59a 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -494,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.Contains(t, w.Body.String(), "previous_response_id")
+}
+
+func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.NotContains(t, w.Body.String(), "reuse previous_response_id")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go
new file mode 100644
index 00000000..4d0078a7
--- /dev/null
+++ b/backend/internal/handler/openai_images.go
@@ -0,0 +1,299 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// Images handles OpenAI Images API requests.
+// POST /v1/images/generations
+// POST /v1/images/edits
+func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
+ streamStarted := false
+ defer h.recoverResponsesPanic(c, &streamStarted)
+
+ requestStart := time.Now()
+
+ apiKey, ok := middleware2.GetAPIKeyFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
+ return
+ }
+
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
+ return
+ }
+ reqLog := requestLogger(
+ c,
+ "handler.openai_gateway.images",
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ )
+ if !h.ensureResponsesDependencies(c, reqLog) {
+ return
+ }
+
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
+ if err != nil {
+ if maxErr, ok := extractMaxBytesError(err); ok {
+ h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
+ return
+ }
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
+ return
+ }
+ if len(body) == 0 {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
+ return
+ }
+
+ if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
+ setOpsRequestContext(c, "", false, nil)
+ } else {
+ setOpsRequestContext(c, "", false, body)
+ }
+
+ parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
+ if err != nil {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+ return
+ }
+
+ reqLog = reqLog.With(
+ zap.String("model", parsed.Model),
+ zap.Bool("stream", parsed.Stream),
+ zap.Bool("multipart", parsed.Multipart),
+ zap.String("capability", string(parsed.RequiredCapability)),
+ )
+
+ if parsed.Multipart {
+ setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
+ } else {
+ setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
+ }
+ setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
+
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model)
+
+ if h.errorPassthroughService != nil {
+ service.BindErrorPassthroughService(c, h.errorPassthroughService)
+ }
+
+ subscription, _ := middleware2.GetSubscriptionFromContext(c)
+
+ service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
+ routingStart := time.Now()
+
+ userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog)
+ if !acquired {
+ return
+ }
+ if userReleaseFunc != nil {
+ defer userReleaseFunc()
+ }
+
+ if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
+ return
+ }
+
+ sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
+
+ maxAccountSwitches := h.maxAccountSwitches
+ switchCount := 0
+ failedAccountIDs := make(map[int64]struct{})
+ sameAccountRetryCount := make(map[int64]int)
+ var lastFailoverErr *service.UpstreamFailoverError
+
+ for {
+ reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
+ selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages(
+ c.Request.Context(),
+ apiKey.GroupID,
+ sessionHash,
+ parsed.Model,
+ failedAccountIDs,
+ parsed.RequiredCapability,
+ )
+ if err != nil {
+ reqLog.Warn("openai.images.account_select_failed",
+ zap.Error(err),
+ zap.Int("excluded_account_count", len(failedAccountIDs)),
+ )
+ if len(failedAccountIDs) == 0 {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
+ return
+ }
+ if lastFailoverErr != nil {
+ h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
+ } else {
+ h.handleFailoverExhaustedSimple(c, 502, streamStarted)
+ }
+ return
+ }
+ if selection == nil || selection.Account == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
+ return
+ }
+
+ reqLog.Debug("openai.images.account_schedule_decision",
+ zap.String("layer", scheduleDecision.Layer),
+ zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
+ zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ zap.Int("top_k", scheduleDecision.TopK),
+ zap.Int64("latency_ms", scheduleDecision.LatencyMs),
+ zap.Float64("load_skew", scheduleDecision.LoadSkew),
+ )
+
+ account := selection.Account
+ sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
+ reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
+ setOpsSelectedAccount(c, account.ID, account.Platform)
+
+ accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog)
+ if !acquired {
+ return
+ }
+
+ service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
+ forwardStart := time.Now()
+ result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
+ forwardDurationMs := time.Since(forwardStart).Milliseconds()
+ if accountReleaseFunc != nil {
+ accountReleaseFunc()
+ }
+ upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
+ responseLatencyMs := forwardDurationMs
+ if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
+ responseLatencyMs = forwardDurationMs - upstreamLatencyMs
+ }
+ service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
+ if err == nil && result != nil && result.FirstTokenMs != nil {
+ service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
+ }
+ if err != nil {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai.images.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
+ }
+ }
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai.images.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ }
+ if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
+ reqLog.Warn("openai.images.forward_failed", fields...)
+ return
+ }
+ reqLog.Error("openai.images.forward_failed", fields...)
+ return
+ }
+
+ if result != nil {
+ if account.Type == service.AccountTypeOAuth {
+ h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ } else {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
+ }
+
+ userAgent := c.GetHeader("User-Agent")
+ clientIP := ip.GetClientIP(c)
+ requestPayloadHash := service.HashUsageRequestPayload(body)
+ if parsed.Multipart {
+ requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
+ }
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ InboundEndpoint: GetInboundEndpoint(c),
+ UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ RequestPayloadHash: requestPayloadHash,
+ APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
+ }); err != nil {
+ logger.L().With(
+ zap.String("component", "handler.openai_gateway.images"),
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ zap.String("model", parsed.Model),
+ zap.Int64("account_id", account.ID),
+ ).Error("openai.images.record_usage_failed", zap.Error(err))
+ }
+ })
+
+ reqLog.Debug("openai.images.request_completed",
+ zap.Int64("account_id", account.ID),
+ zap.Int("switch_count", switchCount),
+ )
+ return
+ }
+}
+
+func isMultipartImagesContentType(contentType string) bool {
+ return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data")
+}
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index 90e90dd0..93554912 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string {
return service.PlatformAntigravity
case strings.HasPrefix(p, "/v1beta/"):
return service.PlatformGemini
- case strings.Contains(p, "/responses"):
+ case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
return service.PlatformOpenAI
default:
return ""
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index 1ddb8ae2..09580442 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -1,9 +1,14 @@
package handler
import (
+ "fmt"
"strconv"
"strings"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -202,10 +207,18 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) {
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
- Amount float64 `json:"amount"`
- PaymentType string `json:"payment_type" binding:"required"`
- OrderType string `json:"order_type"`
- PlanID int64 `json:"plan_id"`
+ Amount float64 `json:"amount"`
+ PaymentType string `json:"payment_type" binding:"required"`
+ OpenID string `json:"openid"`
+ WechatResumeToken string `json:"wechat_resume_token"`
+ ReturnURL string `json:"return_url"`
+ PaymentSource string `json:"payment_source"`
+ OrderType string `json:"order_type"`
+ PlanID int64 `json:"plan_id"`
+ // IsMobile lets the frontend declare its mobile status directly. When
+ // nil we fall back to User-Agent heuristics (which miss iPadOS / some
+ // embedded browsers that strip the "Mobile" keyword).
+ IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -221,17 +234,36 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if strings.TrimSpace(req.WechatResumeToken) != "" {
+ claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+ mobile := isMobile(c)
+ if req.IsMobile != nil {
+ mobile = *req.IsMobile
+ }
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
- UserID: subject.UserID,
- Amount: req.Amount,
- PaymentType: req.PaymentType,
- ClientIP: c.ClientIP(),
- IsMobile: isMobile(c),
- SrcHost: c.Request.Host,
- SrcURL: c.Request.Referer(),
- OrderType: req.OrderType,
- PlanID: req.PlanID,
+ UserID: subject.UserID,
+ Amount: req.Amount,
+ PaymentType: req.PaymentType,
+ OpenID: req.OpenID,
+ ClientIP: c.ClientIP(),
+ IsMobile: mobile,
+ IsWeChatBrowser: isWeChatBrowser(c),
+ SrcHost: c.Request.Host,
+ SrcURL: c.Request.Referer(),
+ ReturnURL: req.ReturnURL,
+ PaymentSource: req.PaymentSource,
+ OrderType: req.OrderType,
+ PlanID: req.PlanID,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -240,6 +272,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.Success(c, result)
}
+func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
+ if req == nil || claims == nil {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
+ }
+ openid := strings.TrimSpace(claims.OpenID)
+ if openid == "" {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+
+ paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
+ if paymentType == "" {
+ paymentType = payment.TypeWxpay
+ }
+ if req.PaymentType != "" {
+ requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
+ if requestPaymentType != "" && requestPaymentType != paymentType {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
+ }
+ }
+ req.PaymentType = paymentType
+ req.OpenID = openid
+
+ if strings.TrimSpace(claims.Amount) != "" {
+ amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
+ if err != nil || amount <= 0 {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
+ }
+ req.Amount = amount
+ }
+ if claims.OrderType != "" {
+ req.OrderType = claims.OrderType
+ }
+ if claims.PlanID > 0 {
+ req.PlanID = claims.PlanID
+ }
+ return nil
+}
+
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
@@ -260,7 +330,7 @@ func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Paginated(c, orders, int64(total), page, pageSize)
+ response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrder returns a single order for the authenticated user.
@@ -282,7 +352,7 @@ func (h *PaymentHandler) GetOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// CancelOrder cancels a pending order for the authenticated user.
@@ -354,6 +424,10 @@ type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"`
}
+type ResolveOrderByResumeTokenRequest struct {
+ ResumeToken string `json:"resume_token" binding:"required"`
+}
+
// VerifyOrder actively queries the upstream payment provider to check
// if payment was made, and processes it if so.
// POST /api/v1/payment/orders/verify
@@ -374,23 +448,57 @@ func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- response.Success(c, order)
+ response.Success(c, sanitizePaymentOrderForResponse(order))
}
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// No user details are exposed — only payment status information.
type PublicOrderResult struct {
- ID int64 `json:"id"`
- OutTradeNo string `json:"out_trade_no"`
- Amount float64 `json:"amount"`
- PayAmount float64 `json:"pay_amount"`
- PaymentType string `json:"payment_type"`
- OrderType string `json:"order_type"`
- Status string `json:"status"`
+ ID int64 `json:"id"`
+ OutTradeNo string `json:"out_trade_no"`
+ Amount float64 `json:"amount"`
+ PayAmount float64 `json:"pay_amount"`
+ FeeRate float64 `json:"fee_rate"`
+ PaymentType string `json:"payment_type"`
+ OrderType string `json:"order_type"`
+ Status string `json:"status"`
+ CreatedAt time.Time `json:"created_at"`
+ ExpiresAt time.Time `json:"expires_at"`
+ PaidAt *time.Time `json:"paid_at,omitempty"`
+ CompletedAt *time.Time `json:"completed_at,omitempty"`
+ RefundAmount float64 `json:"refund_amount"`
+ RefundReason *string `json:"refund_reason,omitempty"`
+ RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
+ RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
+ RefundRequestReason *string `json:"refund_request_reason,omitempty"`
+ PlanID *int64 `json:"plan_id,omitempty"`
}
-// VerifyOrderPublic verifies payment status without requiring authentication.
-// Returns limited order info (no user details) to prevent information leakage.
+func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
+ return PublicOrderResult{
+ ID: order.ID,
+ OutTradeNo: order.OutTradeNo,
+ Amount: order.Amount,
+ PayAmount: order.PayAmount,
+ FeeRate: order.FeeRate,
+ PaymentType: order.PaymentType,
+ OrderType: order.OrderType,
+ Status: order.Status,
+ CreatedAt: order.CreatedAt,
+ ExpiresAt: order.ExpiresAt,
+ PaidAt: order.PaidAt,
+ CompletedAt: order.CompletedAt,
+ RefundAmount: order.RefundAmount,
+ RefundReason: order.RefundReason,
+ RefundRequestedAt: order.RefundRequestedAt,
+ RefundRequestedBy: order.RefundRequestedBy,
+ RefundRequestReason: order.RefundRequestReason,
+ PlanID: order.PlanID,
+ }
+}
+
+// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
+// a compatibility path for older result pages and staggered deploys.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
var req VerifyOrderRequest
@@ -398,20 +506,30 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+
order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, PublicOrderResult{
- ID: order.ID,
- OutTradeNo: order.OutTradeNo,
- Amount: order.Amount,
- PayAmount: order.PayAmount,
- PaymentType: order.PaymentType,
- OrderType: order.OrderType,
- Status: order.Status,
- })
+ response.Success(c, buildPublicOrderResult(order))
+}
+
+// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
+// POST /api/v1/payment/public/orders/resolve
+func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
+ var req ResolveOrderByResumeTokenRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, buildPublicOrderResult(order))
}
// requireAuth extracts the authenticated subject from the context.
@@ -435,3 +553,27 @@ func isMobile(c *gin.Context) bool {
}
return false
}
+
+func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizePaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
+func isWeChatBrowser(c *gin.Context) bool {
+ return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger")
+}
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
new file mode 100644
index 00000000..377f432e
--- /dev/null
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -0,0 +1,368 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ Amount: 0,
+ PaymentType: payment.TypeWxpay,
+ OrderType: payment.OrderTypeBalance,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ })
+ if err != nil {
+ t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
+ }
+ if req.OpenID != "openid-123" {
+ t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
+ }
+ if req.Amount != 12.5 {
+ t.Fatalf("amount = %v, want 12.5", req.Amount)
+ }
+ if req.OrderType != payment.OrderTypeSubscription {
+ t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
+ }
+ if req.PlanID != 7 {
+ t.Fatalf("plan_id = %d, want 7", req.PlanID)
+ }
+}
+
+func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
+ t.Parallel()
+
+ req := CreateOrderRequest{
+ PaymentType: payment.TypeAlipay,
+ }
+
+ err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeBalance,
+ })
+ if err == nil {
+ t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
+ }
+}
+
+func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
+ t.Parallel()
+
+ gin.SetMode(gin.TestMode)
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("public-verify@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-verify-user").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(90.64).
+ SetFeeRate(0.03).
+ SetRechargeCode("PUBLIC-VERIFY").
+ SetOutTradeNo("legacy-order-no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-verify").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(service.OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/verify",
+ bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.VerifyOrderPublic(ctx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ ID int64 `json:"id"`
+ OutTradeNo string `json:"out_trade_no"`
+ Amount float64 `json:"amount"`
+ PayAmount float64 `json:"pay_amount"`
+ FeeRate float64 `json:"fee_rate"`
+ PaymentType string `json:"payment_type"`
+ OrderType string `json:"order_type"`
+ Status string `json:"status"`
+ RefundAmount float64 `json:"refund_amount"`
+ CreatedAt string `json:"created_at"`
+ ExpiresAt string `json:"expires_at"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, order.ID, resp.Data.ID)
+ require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
+ require.Equal(t, 90.64, resp.Data.PayAmount)
+ require.Equal(t, 0.03, resp.Data.FeeRate)
+ require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
+ require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
+ require.Equal(t, service.OrderStatusPending, resp.Data.Status)
+ require.Equal(t, 0.0, resp.Data.RefundAmount)
+ require.NotEmpty(t, resp.Data.CreatedAt)
+ require.NotEmpty(t, resp.Data.ExpiresAt)
+}
+
+func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("public-resolve@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-resolve-user").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(100).
+ SetPayAmount(103).
+ SetFeeRate(0.03).
+ SetRechargeCode("PUBLIC-RESOLVE").
+ SetOutTradeNo("resolve-order-no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-resolve").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(service.OrderStatusPaid).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/resolve",
+ bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.ResolveOrderPublicByResumeToken(ctx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, float64(order.ID), resp.Data["id"])
+ require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"])
+ require.Equal(t, 100.0, resp.Data["amount"])
+ require.Equal(t, 103.0, resp.Data["pay_amount"])
+ require.Equal(t, 0.03, resp.Data["fee_rate"])
+ require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
+ require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
+ require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
+ require.Contains(t, resp.Data, "created_at")
+ require.Contains(t, resp.Data, "expires_at")
+ require.Contains(t, resp.Data, "refund_amount")
+}
+
+func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("public-resolve-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-resolve-mismatch-user").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(100).
+ SetPayAmount(103).
+ SetFeeRate(0.03).
+ SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
+ SetOutTradeNo("resolve-order-mismatch-no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-resolve-mismatch").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(service.OrderStatusPaid).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(context.Background())
+ require.NoError(t, err)
+
+ resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID + 999,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/resolve",
+ bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.ResolveOrderPublicByResumeToken(ctx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
+}
+
+func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
+ h := NewPaymentHandler(paymentSvc, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ ctx, _ := gin.CreateTestContext(recorder)
+ ctx.Request = httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/payment/public/orders/verify",
+ bytes.NewBufferString(`{"out_trade_no":" "}`),
+ )
+ ctx.Request.Header.Set("Content-Type", "application/json")
+
+ h.VerifyOrderPublic(ctx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
+}
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
index 8a83bfeb..9ae799fd 100644
--- a/backend/internal/handler/payment_webhook_handler.go
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -1,6 +1,9 @@
package handler
import (
+ "context"
+ "errors"
+ "fmt"
"io"
"log/slog"
"net/http"
@@ -77,9 +80,13 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
outTradeNo := extractOutTradeNo(rawBody, providerKey)
- provider, err := h.paymentService.GetWebhookProvider(c.Request.Context(), providerKey, outTradeNo)
+ providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
if err != nil {
slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
+ if providerKey == payment.TypeWxpay {
+ c.String(http.StatusBadRequest, "verify failed")
+ return
+ }
writeSuccessResponse(c, providerKey)
return
}
@@ -89,7 +96,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
headers[strings.ToLower(k)] = c.GetHeader(k)
}
- notification, err := provider.VerifyNotification(c.Request.Context(), rawBody, headers)
+ resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
if err != nil {
truncatedBody := rawBody
if len(truncatedBody) > webhookLogTruncateLen {
@@ -103,24 +110,38 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
if notification == nil {
- writeSuccessResponse(c, providerKey)
+ writeSuccessResponse(c, resolvedProviderKey)
return
}
- if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, providerKey); err != nil {
- slog.Error("[Payment Webhook] handle notification failed", "provider", providerKey, "error", err)
+ if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
+ // Unknown order: ack with 2xx so the provider stops retrying. This
+ // guards against foreign environments whose webhook endpoints are
+ // (mis)configured to point at us — without a 2xx, the provider will
+ // retry for days and spam our error logs. We still emit a WARN so the
+ // event is discoverable in logs.
+ if errors.Is(err, service.ErrOrderNotFound) {
+ slog.Warn("[Payment Webhook] unknown order, acking to stop retries",
+ "provider", resolvedProviderKey,
+ "outTradeNo", notification.OrderID,
+ "tradeNo", notification.TradeNo,
+ )
+ writeSuccessResponse(c, resolvedProviderKey)
+ return
+ }
+ slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
c.String(http.StatusInternalServerError, "handle failed")
return
}
- writeSuccessResponse(c, providerKey)
+ writeSuccessResponse(c, resolvedProviderKey)
}
// extractOutTradeNo parses the webhook body to find the out_trade_no.
// This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey {
- case payment.TypeEasyPay:
+ case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody)
if err == nil {
return values.Get("out_trade_no")
@@ -131,6 +152,25 @@ func extractOutTradeNo(rawBody, providerKey string) string {
return ""
}
+func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
+ var lastErr error
+ for _, provider := range providers {
+ if provider == nil {
+ continue
+ }
+ notification, err := provider.VerifyNotification(ctx, rawBody, headers)
+ if err != nil {
+ lastErr = err
+ continue
+ }
+ return provider.ProviderKey(), notification, nil
+ }
+ if lastErr != nil {
+ return "", nil, lastErr
+ }
+ return "", nil, fmt.Errorf("no webhook provider could verify notification")
+}
+
// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
type wxpaySuccessResponse struct {
Code string `json:"code"`
diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go
index bdef1766..7551fc83 100644
--- a/backend/internal/handler/payment_webhook_handler_test.go
+++ b/backend/internal/handler/payment_webhook_handler_test.go
@@ -3,11 +3,16 @@
package handler
import (
+ "context"
"encoding/json"
+ "errors"
+ "fmt"
"net/http"
"net/http/httptest"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -88,6 +93,43 @@ func TestWriteSuccessResponse(t *testing.T) {
}
}
+// TestUnknownOrderWebhookAcksWithSuccess exercises the response contract that
+// handleNotify relies on when HandlePaymentNotification returns ErrOrderNotFound:
+// we still need to emit the provider-specific 2xx so the provider stops
+// retrying. We can't easily drive handleNotify end-to-end without mocking the
+// concrete *service.PaymentService, so this test locks down the two ingredients
+// the fix depends on:
+// 1. errors.Is recognises the sentinel through fmt.Errorf %w wrapping (which
+// is how service layer wraps it with the out_trade_no context).
+// 2. writeSuccessResponse produces the provider-specific body for Stripe
+// (empty 200) — matching what handleNotify calls on the ack path.
+//
+// If either contract breaks, the Stripe "unknown order → 500 loop" regresses.
+func TestUnknownOrderWebhookAcksWithSuccess(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ // 1) Sentinel recognition through wrapping.
+ wrapped := fmt.Errorf("%w: out_trade_no=sub2_missing_42", service.ErrOrderNotFound)
+ require.True(t, errors.Is(wrapped, service.ErrOrderNotFound),
+ "handleNotify uses errors.Is on the wrapped service error; regression here "+
+ "would mean unknown-order webhooks go back to returning 500 and looping forever")
+
+ // A distinct error must NOT match — otherwise a DB failure would be silently
+ // swallowed as an ack.
+ other := errors.New("lookup order failed: connection refused")
+ require.False(t, errors.Is(other, service.ErrOrderNotFound))
+
+ // 2) Provider-specific success body is what handleNotify emits on the
+ // ack path. Asserted again here because this is the shape Stripe expects
+ // to consider the webhook acknowledged.
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ writeSuccessResponse(c, payment.TypeStripe)
+ require.Equal(t, http.StatusOK, w.Code,
+ "Stripe requires 2xx to stop retrying; anything else restarts the retry loop")
+ require.Empty(t, w.Body.String(), "Stripe expects an empty body on the ack path")
+}
+
func TestWebhookConstants(t *testing.T) {
t.Run("maxWebhookBodySize is 1MB", func(t *testing.T) {
assert.Equal(t, int64(1<<20), int64(maxWebhookBodySize))
@@ -97,3 +139,104 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen)
})
}
+
+func TestExtractOutTradeNo(t *testing.T) {
+ tests := []struct {
+ name string
+ providerKey string
+ rawBody string
+ want string
+ }{
+ {
+ name: "easypay query payload",
+ providerKey: "easypay",
+ rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
+ want: "sub2_123",
+ },
+ {
+ name: "alipay query payload",
+ providerKey: "alipay",
+ rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
+ want: "sub2_456",
+ },
+ {
+ name: "unknown provider",
+ providerKey: "wxpay",
+ rawBody: "{}",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
+ })
+ }
+}
+
+func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
+ firstErr := errors.New("wrong provider")
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: firstErr,
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ notification: &payment.PaymentNotification{
+ OrderID: "sub2_42",
+ TradeNo: "trade-42",
+ Status: payment.NotificationStatusSuccess,
+ },
+ },
+ }
+
+ providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
+ require.NoError(t, err)
+ require.Equal(t, payment.TypeWxpay, providerKey)
+ require.NotNil(t, notification)
+ require.Equal(t, "sub2_42", notification.OrderID)
+}
+
+func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed a"),
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed b"),
+ },
+ }
+
+ _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
+ require.Error(t, err)
+}
+
+type webhookHandlerProviderStub struct {
+ key string
+ notification *payment.PaymentNotification
+ verifyErr error
+}
+
+func (p webhookHandlerProviderStub) Name() string { return p.key }
+func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
+func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.PaymentType(p.key)}
+}
+func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ if p.verifyErr != nil {
+ return nil, p.verifyErr
+ }
+ return p.notification, nil
+}
+func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 1717b7a1..22f2aa15 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
@@ -56,6 +57,10 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
+ WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
+ WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
+ WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
@@ -65,5 +70,12 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
+
+ AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: settings.AffiliateEnabled,
})
}
diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go
new file mode 100644
index 00000000..45d66f8e
--- /dev/null
+++ b/backend/internal/handler/setting_handler_public_test.go
@@ -0,0 +1,122 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerPublicRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-mp-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ service.SettingKeyWeChatConnectMode: "mp",
+ service.SettingKeyWeChatConnectScopes: "snsapi_base",
+ service.SettingKeyWeChatConnectOpenEnabled: "true",
+ service.SettingKeyWeChatConnectMPEnabled: "true",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.WeChatOAuthEnabled)
+ require.True(t, resp.Data.WeChatOAuthOpenEnabled)
+ require.True(t, resp.Data.WeChatOAuthMPEnabled)
+}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 2535ea5e..3f6ed8c2 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -1,6 +1,9 @@
package handler
import (
+ "context"
+ "strings"
+
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -11,17 +14,27 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
- userService *service.UserService
- emailService *service.EmailService
- emailCache service.EmailCache
+ userService *service.UserService
+ authService *service.AuthService
+ emailService *service.EmailService
+ emailCache service.EmailCache
+ affiliateService *service.AffiliateService
}
// NewUserHandler creates a new UserHandler
-func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
+func NewUserHandler(
+ userService *service.UserService,
+ authService *service.AuthService,
+ emailService *service.EmailService,
+ emailCache service.EmailCache,
+ affiliateService *service.AffiliateService,
+) *UserHandler {
return &UserHandler{
- userService: userService,
- emailService: emailService,
- emailCache: emailCache,
+ userService: userService,
+ authService: authService,
+ emailService: emailService,
+ emailCache: emailCache,
+ affiliateService: affiliateService,
}
}
@@ -34,10 +47,33 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type userProfileResponse struct {
+ dto.User
+ AvatarURL string `json:"avatar_url,omitempty"`
+ AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
+ UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
+ DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
+ NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
+ ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
+ Identities service.UserIdentitySummarySet `json:"identities"`
+ AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
+ IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
+ EmailBound bool `json:"email_bound"`
+ LinuxDoBound bool `json:"linuxdo_bound"`
+ OIDCBound bool `json:"oidc_bound"`
+ WeChatBound bool `json:"wechat_bound"`
+}
+
+type userProfileSourceContext struct {
+ Provider string `json:"provider,omitempty"`
+ Source string `json:"source,omitempty"`
+}
+
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
@@ -47,13 +83,19 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
- userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, dto.UserFromService(userData))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// ChangePassword handles changing user password
@@ -101,6 +143,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{
Username: req.Username,
+ AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
@@ -110,7 +153,193 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// GetAffiliate returns the current user's affiliate details.
+// GET /api/v1/user/aff
+func (h *UserHandler) GetAffiliate(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, detail)
+}
+
+// TransferAffiliateQuota transfers all available affiliate quota into current balance.
+// POST /api/v1/user/aff/transfer
+func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{
+ "transferred_quota": transferred,
+ "balance": balance,
+ })
+}
+
+type StartIdentityBindingRequest struct {
+ Provider string `json:"provider" binding:"required"`
+ RedirectTo string `json:"redirect_to"`
+}
+
+type BindEmailIdentityRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code" binding:"required"`
+ Password string `json:"password" binding:"required"`
+}
+
+type SendEmailBindingCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
+// POST /api/v1/user/auth-identities/bind/start
+func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
+ if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req StartIdentityBindingRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{
+ Provider: req.Provider,
+ RedirectTo: req.RedirectTo,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, result)
+}
+
+// BindEmailIdentity verifies and binds a local email identity for the current user.
+// POST /api/v1/user/account-bindings/email
+func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req BindEmailIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ updatedUser, err := h.authService.BindEmailIdentity(
+ c.Request.Context(),
+ subject.UserID,
+ req.Email,
+ req.VerifyCode,
+ req.Password,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// UnbindIdentity removes a third-party sign-in provider from the current user.
+// DELETE /api/v1/user/account-bindings/:provider
+func (h *UserHandler) UnbindIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult(
+ c.Request.Context(),
+ subject.UserID,
+ c.Param("provider"),
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if unbound && h.authService != nil {
+ if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// SendEmailBindingCode sends a verification code for the current user's email binding flow.
+// POST /api/v1/user/account-bindings/email/send-code
+func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req SendEmailBindingCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
@@ -176,7 +405,13 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
@@ -212,7 +447,13 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
@@ -248,5 +489,117 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) {
+ identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user)
+ if err != nil {
+ return userProfileResponse{}, err
+ }
+ return userProfileResponseFromService(user, identities), nil
+}
+
+func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse {
+ base := dto.UserFromService(user)
+ if base == nil {
+ return userProfileResponse{}
+ }
+ bindings := userProfileBindingMap(identities)
+ profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
+ return userProfileResponse{
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ AvatarSource: avatarSource,
+ UsernameSource: usernameSource,
+ DisplayNameSource: usernameSource,
+ NicknameSource: usernameSource,
+ ProfileSources: profileSources,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
+ }
+}
+
+func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
+ return map[string]service.UserIdentitySummary{
+ "email": identities.Email,
+ "linuxdo": identities.LinuxDo,
+ "oidc": identities.OIDC,
+ "wechat": identities.WeChat,
+ }
+}
+
+func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
+ map[string]*userProfileSourceContext,
+ *userProfileSourceContext,
+ *userProfileSourceContext,
+) {
+ if user == nil {
+ return nil, nil, nil
+ }
+
+ thirdParty := thirdPartyIdentityProviders(identities)
+ var avatarSource *userProfileSourceContext
+ avatarValue := strings.TrimSpace(user.AvatarURL)
+ for _, summary := range thirdParty {
+ if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) {
+ avatarSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+
+ usernameValue := strings.TrimSpace(user.Username)
+ var usernameSource *userProfileSourceContext
+ for _, summary := range thirdParty {
+ if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
+ usernameSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+
+ profileSources := map[string]*userProfileSourceContext{}
+ if avatarSource != nil {
+ profileSources["avatar"] = avatarSource
+ }
+ if usernameSource != nil {
+ profileSources["username"] = usernameSource
+ profileSources["display_name"] = usernameSource
+ profileSources["nickname"] = usernameSource
+ }
+ if len(profileSources) == 0 {
+ return nil, avatarSource, usernameSource
+ }
+ return profileSources, avatarSource, usernameSource
+}
+
+func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
+ out := make([]service.UserIdentitySummary, 0, 3)
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ if summary.Bound {
+ out = append(out, summary)
+ }
+ }
+ return out
+}
+
+func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
+ provider = strings.TrimSpace(provider)
+ if provider == "" {
+ return nil
+ }
+ return &userProfileSourceContext{
+ Provider: provider,
+ Source: provider,
+ }
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
new file mode 100644
index 00000000..8a864b51
--- /dev/null
+++ b/backend/internal/handler/user_handler_test.go
@@ -0,0 +1,783 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type userHandlerRepoStub struct {
+ user *service.User
+ identities []service.UserAuthIdentityRecord
+ unbound []string
+}
+
+func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
+func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
+ cloned := *user
+ s.user = &cloned
+ return nil
+}
+func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ if s.user == nil || s.user.AvatarURL == "" {
+ return nil, nil
+ }
+ return &service.UserAvatar{
+ StorageProvider: s.user.AvatarSource,
+ URL: s.user.AvatarURL,
+ ContentType: s.user.AvatarMIME,
+ ByteSize: s.user.AvatarByteSize,
+ SHA256: s.user.AvatarSHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ s.user.AvatarURL = input.URL
+ s.user.AvatarSource = input.StorageProvider
+ s.user.AvatarMIME = input.ContentType
+ s.user.AvatarByteSize = input.ByteSize
+ s.user.AvatarSHA256 = input.SHA256
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ s.user.AvatarURL = ""
+ s.user.AvatarSource = ""
+ s.user.AvatarMIME = ""
+ s.user.AvatarByteSize = 0
+ s.user.AvatarSHA256 = ""
+ return nil
+}
+func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error {
+ if s.user != nil {
+ s.user.LastActiveAt = &activeAt
+ }
+ return nil
+}
+func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
+ out := make([]service.UserAuthIdentityRecord, len(s.identities))
+ copy(out, s.identities)
+ return out, nil
+}
+func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
+ s.unbound = append(s.unbound, provider)
+ filtered := s.identities[:0]
+ for _, identity := range s.identities {
+ if identity.ProviderType == provider {
+ continue
+ }
+ filtered = append(filtered, identity)
+ }
+ s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...)
+ return nil
+}
+
+func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "handler-avatar@example.com",
+ Username: "handler-avatar",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.UpdateProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ AvatarURL string `json:"avatar_url"`
+ Username string `json:"username"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
+ require.Equal(t, "handler-avatar", resp.Data.Username)
+}
+
+func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-123456",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ {
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example.com",
+ ProviderSubject: "oidc-user-abc",
+ Metadata: map[string]any{
+ "suggested_display_name": "OIDC Display",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Identities struct {
+ Email struct {
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name"`
+ } `json:"email"`
+ LinuxDo struct {
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name"`
+ ProviderKey string `json:"provider_key"`
+ } `json:"linuxdo"`
+ OIDC struct {
+ Bound bool `json:"bound"`
+ DisplayName string `json:"display_name"`
+ ProviderKey string `json:"provider_key"`
+ } `json:"oidc"`
+ WeChat struct {
+ Bound bool `json:"bound"`
+ CanBind bool `json:"can_bind"`
+ BindStartPath string `json:"bind_start_path"`
+ } `json:"wechat"`
+ } `json:"identities"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.Identities.Email.Bound)
+ require.Equal(t, 1, resp.Data.Identities.Email.BoundCount)
+ require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName)
+ require.True(t, resp.Data.Identities.LinuxDo.Bound)
+ require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount)
+ require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName)
+ require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey)
+ require.True(t, resp.Data.Identities.OIDC.Bound)
+ require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName)
+ require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey)
+ require.False(t, resp.Data.Identities.WeChat.Bound)
+ require.True(t, resp.Data.Identities.WeChat.CanBind)
+ require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start")
+}
+
+func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "legacy-profile@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, false, resp.Data["oidc_bound"])
+ require.Equal(t, false, resp.Data["wechat_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+ require.Equal(t, "linuxdo", linuxdoBinding["provider"])
+
+ identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
+ require.True(t, ok)
+ emailBinding, ok := identityBindings["email"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, emailBinding["bound"])
+ require.Equal(t, "profile.authBindings.notes.emailManagedFromProfile", emailBinding["note_key"])
+
+ linuxdoCompatBinding, ok := identityBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "profile.authBindings.notes.canUnbind", linuxdoCompatBinding["note_key"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
+}
+
+func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 22,
+ Email: "edited-profile@example.com",
+ Username: "custom-name",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/custom.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-22",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.NotContains(t, resp.Data, "avatar_source")
+ require.NotContains(t, resp.Data, "username_source")
+ require.NotContains(t, resp.Data, "profile_sources")
+}
+
+type userHandlerEmailCacheStub struct {
+ data *service.VerificationCodeData
+}
+
+type userHandlerRefreshTokenCacheStub struct {
+ revokedUserIDs []int64
+}
+
+func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
+ s.revokedUserIDs = append(s.revokedUserIDs, userID)
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return s.data, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Params = gin.Params{{Key: "provider", Value: "email"}}
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Email string `json:"email"`
+ EmailBound bool `json:"email_bound"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "new@example.com", resp.Data.Email)
+ require.True(t, resp.Data.EmailBound)
+}
+
+func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []string{"linuxdo"}, repo.unbound)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, linuxdoBinding["bound"])
+}
+
+func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 23,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-23",
+ },
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(5), repo.user.TokenVersion)
+}
+
+func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 24,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Empty(t, repo.unbound)
+ require.Empty(t, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(4), repo.user.TokenVersion)
+}
+
+func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ user := &service.User{
+ ID: 11,
+ Email: "current@example.com",
+ Username: "bound-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, user.SetPassword("current-password"))
+
+ repo := &userHandlerRepoStub{user: user}
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "PASSWORD_INCORRECT", resp.Reason)
+ require.Equal(t, "current password is incorrect", resp.Message)
+ require.Equal(t, "current@example.com", repo.user.Email)
+}
+
+func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.StartIdentityBinding(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Provider string `json:"provider"`
+ AuthorizeURL string `json:"authorize_url"`
+ Method string `json:"method"`
+ UseBrowserRedirect bool `json:"use_browser_redirect"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "wechat", resp.Data.Provider)
+ require.Equal(t, "GET", resp.Data.Method)
+ require.True(t, resp.Data.UseBrowserRedirect)
+ require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start")
+ require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
+ require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
+}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 4b54d41a..a8725875 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -34,35 +34,41 @@ func ProvideAdminHandlers(
apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler,
channelHandler *admin.ChannelHandler,
+ channelMonitorHandler *admin.ChannelMonitorHandler,
+ channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
paymentHandler *admin.PaymentHandler,
+ affiliateHandler *admin.AffiliateHandler,
) *AdminHandlers {
return &AdminHandlers{
- Dashboard: dashboardHandler,
- User: userHandler,
- Group: groupHandler,
- Account: accountHandler,
- Announcement: announcementHandler,
- DataManagement: dataManagementHandler,
- Backup: backupHandler,
- OAuth: oauthHandler,
- OpenAIOAuth: openaiOAuthHandler,
- GeminiOAuth: geminiOAuthHandler,
- AntigravityOAuth: antigravityOAuthHandler,
- Proxy: proxyHandler,
- Redeem: redeemHandler,
- Promo: promoHandler,
- Setting: settingHandler,
- Ops: opsHandler,
- System: systemHandler,
- Subscription: subscriptionHandler,
- Usage: usageHandler,
- UserAttribute: userAttributeHandler,
- ErrorPassthrough: errorPassthroughHandler,
- TLSFingerprintProfile: tlsFingerprintProfileHandler,
- APIKey: apiKeyHandler,
- ScheduledTest: scheduledTestHandler,
- Channel: channelHandler,
- Payment: paymentHandler,
+ Dashboard: dashboardHandler,
+ User: userHandler,
+ Group: groupHandler,
+ Account: accountHandler,
+ Announcement: announcementHandler,
+ DataManagement: dataManagementHandler,
+ Backup: backupHandler,
+ OAuth: oauthHandler,
+ OpenAIOAuth: openaiOAuthHandler,
+ GeminiOAuth: geminiOAuthHandler,
+ AntigravityOAuth: antigravityOAuthHandler,
+ Proxy: proxyHandler,
+ Redeem: redeemHandler,
+ Promo: promoHandler,
+ Setting: settingHandler,
+ Ops: opsHandler,
+ System: systemHandler,
+ Subscription: subscriptionHandler,
+ Usage: usageHandler,
+ UserAttribute: userAttributeHandler,
+ ErrorPassthrough: errorPassthroughHandler,
+ TLSFingerprintProfile: tlsFingerprintProfileHandler,
+ APIKey: apiKeyHandler,
+ ScheduledTest: scheduledTestHandler,
+ Channel: channelHandler,
+ ChannelMonitor: channelMonitorHandler,
+ ChannelMonitorTemplate: channelMonitorTemplateHandler,
+ Payment: paymentHandler,
+ Affiliate: affiliateHandler,
}
}
@@ -85,6 +91,7 @@ func ProvideHandlers(
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
+ channelMonitorUserHandler *ChannelMonitorUserHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
@@ -92,24 +99,27 @@ func ProvideHandlers(
totpHandler *TotpHandler,
paymentHandler *PaymentHandler,
paymentWebhookHandler *PaymentWebhookHandler,
+ availableChannelHandler *AvailableChannelHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers {
return &Handlers{
- Auth: authHandler,
- User: userHandler,
- APIKey: apiKeyHandler,
- Usage: usageHandler,
- Redeem: redeemHandler,
- Subscription: subscriptionHandler,
- Announcement: announcementHandler,
- Admin: adminHandlers,
- Gateway: gatewayHandler,
- OpenAIGateway: openaiGatewayHandler,
- Setting: settingHandler,
- Totp: totpHandler,
- Payment: paymentHandler,
- PaymentWebhook: paymentWebhookHandler,
+ Auth: authHandler,
+ User: userHandler,
+ APIKey: apiKeyHandler,
+ Usage: usageHandler,
+ Redeem: redeemHandler,
+ Subscription: subscriptionHandler,
+ Announcement: announcementHandler,
+ ChannelMonitor: channelMonitorUserHandler,
+ Admin: adminHandlers,
+ Gateway: gatewayHandler,
+ OpenAIGateway: openaiGatewayHandler,
+ Setting: settingHandler,
+ Totp: totpHandler,
+ Payment: paymentHandler,
+ PaymentWebhook: paymentWebhookHandler,
+ AvailableChannel: availableChannelHandler,
}
}
@@ -123,12 +133,14 @@ var ProviderSet = wire.NewSet(
NewRedeemHandler,
NewSubscriptionHandler,
NewAnnouncementHandler,
+ NewChannelMonitorUserHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
NewPaymentHandler,
NewPaymentWebhookHandler,
+ NewAvailableChannelHandler,
// Admin handlers
admin.NewDashboardHandler,
@@ -156,7 +168,10 @@ var ProviderSet = wire.NewSet(
admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler,
admin.NewChannelHandler,
+ admin.NewChannelMonitorHandler,
+ admin.NewChannelMonitorRequestTemplateHandler,
admin.NewPaymentHandler,
+ admin.NewAffiliateHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
index e39e957f..0581469d 100644
--- a/backend/internal/payment/crypto.go
+++ b/backend/internal/payment/crypto.go
@@ -10,12 +10,20 @@ import (
"strings"
)
+// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
+const AES256KeySize = 32
+
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function is kept only for seeding legacy ciphertext in tests and for
+// the transitional Decrypt fallback. Scheduled for removal after all live
+// deployments complete migration by re-saving their configs.
func Encrypt(plaintext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
block, err := aes.NewCipher(key)
@@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function remains only as a read-path fallback for pre-migration
+// ciphertext records. Scheduled for removal once all deployments re-save
+// their provider configs through the admin UI.
func Decrypt(ciphertext string, key []byte) (string, error) {
- if len(key) != 32 {
- return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key))
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
}
parts := strings.SplitN(ciphertext, ":", 3)
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
index f0353173..41fd2c50 100644
--- a/backend/internal/payment/load_balancer.go
+++ b/backend/internal/payment/load_balancer.go
@@ -45,11 +45,31 @@ type DefaultLoadBalancer struct {
counter atomic.Uint64
}
+type contextKey string
+
+const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
+
// NewDefaultLoadBalancer creates a new load balancer.
func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
}
+func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
+ appID = strings.TrimSpace(appID)
+ if appID == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
+}
+
+func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
+ return strings.TrimSpace(appID)
+}
+
// instanceCandidate pairs an instance with its pre-fetched daily usage.
type instanceCandidate struct {
inst *dbent.PaymentProviderInstance
@@ -116,6 +136,7 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
}
var matched []*dbent.PaymentProviderInstance
+ expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
for _, inst := range instances {
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
@@ -124,6 +145,16 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
+ if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
+ config, cfgErr := lb.decryptConfig(inst.Config)
+ if cfgErr != nil {
+ slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
+ continue
+ }
+ if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
+ continue
+ }
+ }
matched = append(matched, inst)
}
}
@@ -231,6 +262,11 @@ func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType P
if cl, ok := limits[lookupKey]; ok {
return cl
}
+ if aliasKey := legacyVisibleMethodAlias(lookupKey); aliasKey != "" {
+ if cl, ok := limits[aliasKey]; ok {
+ return cl
+ }
+ }
return ChannelLimits{}
}
@@ -261,6 +297,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
if err != nil {
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
}
+ if config == nil {
+ config = map[string]string{}
+ }
if selected.PaymentMode != "" {
config["paymentMode"] = selected.PaymentMode
@@ -275,16 +314,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
}, nil
}
-func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) {
- plaintext, err := Decrypt(encrypted, lb.encryptionKey)
- if err != nil {
- return nil, err
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
+// Unreadable values (legacy ciphertext without a valid key, or malformed data)
+// are treated as empty so the service keeps running while the admin re-enters
+// the config via the UI.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
+// transitional compatibility shim for pre-plaintext records. Remove it (and
+// the encryptionKey field + the Decrypt import) after a few releases once all
+// live deployments have re-saved their provider configs through the UI.
+func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
+ return nil, nil
}
var config map[string]string
- if err := json.Unmarshal([]byte(plaintext), &config); err != nil {
- return nil, fmt.Errorf("unmarshal config: %w", err)
+ if err := json.Unmarshal([]byte(stored), &config); err == nil {
+ return config, nil
}
- return config, nil
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
+ if len(lb.encryptionKey) == AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
+ if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
+ return config, nil
+ }
+ }
+ }
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
// GetInstanceDailyAmount returns the total completed order amount for an instance today.
@@ -321,14 +380,45 @@ func InstanceSupportsType(supportedTypes string, target PaymentType) bool {
if supportedTypes == "" {
return true
}
+ normalizedTarget := normalizeVisibleMethodSupportType(target)
for _, t := range strings.Split(supportedTypes, ",") {
- if strings.TrimSpace(t) == target {
+ supported := strings.TrimSpace(t)
+ if supported == target || normalizeVisibleMethodSupportType(supported) == normalizedTarget {
return true
}
}
return false
}
+func normalizeVisibleMethodSupportType(paymentType PaymentType) PaymentType {
+ switch strings.TrimSpace(paymentType) {
+ case TypeAlipay, TypeAlipayDirect:
+ return TypeAlipay
+ case TypeWxpay, TypeWxpayDirect:
+ return TypeWxpay
+ default:
+ return strings.TrimSpace(paymentType)
+ }
+}
+
+func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
+ switch normalizeVisibleMethodSupportType(paymentType) {
+ case TypeAlipay:
+ return TypeAlipayDirect
+ case TypeWxpay:
+ return TypeWxpayDirect
+ default:
+ return ""
+ }
+}
+
+func resolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go
index 04b3c25b..ed08a7dd 100644
--- a/backend/internal/payment/load_balancer_test.go
+++ b/backend/internal/payment/load_balancer_test.go
@@ -68,10 +68,16 @@ func TestInstanceSupportsType(t *testing.T) {
expected: true,
},
{
- name: "partial match should not succeed",
+ name: "legacy alipay direct supports canonical visible method",
supportedTypes: "alipay_direct",
target: "alipay",
- expected: false,
+ expected: true,
+ },
+ {
+ name: "legacy wxpay direct supports canonical visible method",
+ supportedTypes: "wxpay_direct",
+ target: "wxpay",
+ expected: true,
},
{
name: "empty supported types means all supported",
@@ -92,6 +98,22 @@ func TestInstanceSupportsType(t *testing.T) {
}
}
+func TestGetInstanceChannelLimitsFallsBackToLegacyDirectAliases(t *testing.T) {
+ t.Parallel()
+
+ inst := testInstance(1, TypeAlipay, makeLimitsJSON(TypeAlipayDirect, ChannelLimits{SingleMax: 66}))
+ got := getInstanceChannelLimits(inst, TypeAlipay)
+ if got.SingleMax != 66 {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMax=66", got)
+ }
+
+ wxInst := testInstance(2, TypeWxpay, makeLimitsJSON(TypeWxpayDirect, ChannelLimits{SingleMin: 8}))
+ wxGot := getInstanceChannelLimits(wxInst, TypeWxpay)
+ if wxGot.SingleMin != 8 {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMin=8", wxGot)
+ }
+}
+
// ---------------------------------------------------------------------------
// Helper to build test PaymentProviderInstance values
// ---------------------------------------------------------------------------
@@ -452,6 +474,103 @@ func TestStartOfDay(t *testing.T) {
}
}
+func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
+ t.Parallel()
+
+ key := make([]byte, AES256KeySize)
+ for i := range key {
+ key[i] = byte(i + 1)
+ }
+ wrongKey := make([]byte, AES256KeySize)
+ for i := range wrongKey {
+ wrongKey[i] = byte(0xFF - i)
+ }
+
+ plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
+
+ legacyEncrypted, err := Encrypt(plaintextJSON, key)
+ if err != nil {
+ t.Fatalf("seed Encrypt: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ stored string
+ key []byte
+ want map[string]string
+ }{
+ {
+ name: "empty stored returns nil map",
+ stored: "",
+ key: key,
+ want: nil,
+ },
+ {
+ name: "plaintext JSON parses directly",
+ stored: plaintextJSON,
+ key: nil,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "plaintext JSON works even with key present",
+ stored: plaintextJSON,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with correct key decrypts",
+ stored: legacyEncrypted,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with no key treated as empty",
+ stored: legacyEncrypted,
+ key: nil,
+ want: nil,
+ },
+ {
+ name: "legacy ciphertext with wrong key treated as empty",
+ stored: legacyEncrypted,
+ key: wrongKey,
+ want: nil,
+ },
+ {
+ name: "garbage data treated as empty",
+ stored: "not-json-and-not-ciphertext",
+ key: key,
+ want: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ lb := NewDefaultLoadBalancer(nil, tt.key)
+ got, err := lb.decryptConfig(tt.stored)
+ if err != nil {
+ t.Fatalf("decryptConfig unexpected error: %v", err)
+ }
+ if !stringMapEqual(got, tt.want) {
+ t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+// stringMapEqual compares two map[string]string values; nil and empty are equal.
+func stringMapEqual(a, b map[string]string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, v := range a {
+ if bv, ok := b[k]; !ok || bv != v {
+ return false
+ }
+ }
+ return true
+}
+
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index af8a90c6..1234b568 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -15,8 +15,9 @@ import (
// Alipay product codes.
const (
- alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
- alipayProductCodeWapPay = "QUICK_WAP_WAY"
+ alipayProductCodePreCreate = "FACE_TO_FACE_PAYMENT"
+ alipayProductCodeWapPay = "QUICK_WAP_WAY"
+ alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
)
// Alipay response constants.
@@ -26,6 +27,18 @@ const (
alipayRefundSuffix = "-refund"
)
+var (
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ return client.TradeWapPay(param)
+ }
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ return client.TradePreCreate(ctx, param)
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ return client.TradePagePay(param)
+ }
+)
+
// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
type Alipay struct {
instanceID string
@@ -79,8 +92,24 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay}
}
-// CreatePayment creates an Alipay payment page URL.
-func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+func (a *Alipay) MerchantIdentityMetadata() map[string]string {
+ if a == nil {
+ return nil
+ }
+ appID := strings.TrimSpace(a.config["appId"])
+ if appID == "" {
+ return nil
+ }
+ return map[string]string{"app_id": appID}
+}
+
+// CreatePayment creates an Alipay payment using the following routing:
+// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
+// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
+// - Desktop fallback: if precreate is unavailable for the merchant, fall back
+// to alipay.trade.page.pay and expose both pay_url and qr_code so the
+// frontend can render a QR while still allowing direct page open.
+func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
return nil, err
@@ -96,31 +125,73 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
}
if req.IsMobile {
- return a.createTrade(client, req, notifyURL, returnURL, true)
+ return a.createWapTrade(client, req, notifyURL, returnURL)
}
- return a.createTrade(client, req, notifyURL, returnURL, false)
+ return a.createDesktopTrade(ctx, client, req, notifyURL, returnURL)
}
-func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) {
- if isMobile {
- param := alipay.TradeWapPay{}
- param.OutTradeNo = req.OrderID
- param.TotalAmount = req.Amount
- param.Subject = req.Subject
- param.ProductCode = alipayProductCodeWapPay
- param.NotifyURL = notifyURL
- param.ReturnURL = returnURL
+func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradeWapPay{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodeWapPay
+ param.NotifyURL = notifyURL
+ param.ReturnURL = returnURL
- payURL, err := client.TradeWapPay(param)
- if err != nil {
- return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
- }
- return &payment.CreatePaymentResponse{
- TradeNo: req.OrderID,
- PayURL: payURL.String(),
- }, nil
+ payURL, err := alipayTradeWapPay(client, param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
+ }
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ PayURL: payURL.String(),
+ }, nil
+}
+
+func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
+ if precreateErr == nil {
+ return resp, nil
}
+ resp, pagePayErr := a.createPagePayTrade(client, req, notifyURL, returnURL)
+ if pagePayErr == nil {
+ return resp, nil
+ }
+
+ return nil, fmt.Errorf("alipay desktop payment failed: precreate=%v; pagepay=%w", precreateErr, pagePayErr)
+}
+
+func (a *Alipay) createPrecreateTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradePreCreate{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodePreCreate
+ param.NotifyURL = notifyURL
+
+ rsp, err := alipayTradePreCreate(ctx, client, param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradePreCreate: %w", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("alipay TradePreCreate: empty response")
+ }
+ if rsp.IsFailure() {
+ return nil, fmt.Errorf("alipay TradePreCreate failed: %s", rsp.Error.Error())
+ }
+ if strings.TrimSpace(rsp.QRCode) == "" {
+ return nil, fmt.Errorf("alipay TradePreCreate: empty qr_code")
+ }
+
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ QRCode: rsp.QRCode,
+ }, nil
+}
+
+func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
param := alipay.TradePagePay{}
param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount
@@ -129,7 +200,7 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
param.NotifyURL = notifyURL
param.ReturnURL = returnURL
- payURL, err := client.TradePagePay(param)
+ payURL, err := alipayTradePagePay(client, param)
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
@@ -168,14 +239,23 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query
amount, err := strconv.ParseFloat(result.TotalAmount, 64)
if err != nil {
- return nil, fmt.Errorf("alipay parse amount %q: %w", result.TotalAmount, err)
+ amount, err = parseAlipayAmount(
+ result.TotalAmount,
+ result.ReceiptAmount,
+ result.BuyerPayAmount,
+ result.InvoiceAmount,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("alipay parse amount: %w", err)
+ }
}
return &payment.QueryOrderResponse{
- TradeNo: result.TradeNo,
- Status: status,
- Amount: amount,
- PaidAt: result.SendPayDate,
+ TradeNo: result.TradeNo,
+ Status: status,
+ Amount: amount,
+ PaidAt: result.SendPayDate,
+ Metadata: a.MerchantIdentityMetadata(),
}, nil
}
@@ -203,15 +283,31 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s
amount, err := strconv.ParseFloat(notification.TotalAmount, 64)
if err != nil {
- return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err)
+ amount, err = parseAlipayAmount(
+ notification.TotalAmount,
+ notification.ReceiptAmount,
+ notification.BuyerPayAmount,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("alipay parse notification amount: %w", err)
+ }
+ }
+
+ metadata := a.MerchantIdentityMetadata()
+ if appID := strings.TrimSpace(notification.AppId); appID != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["app_id"] = appID
}
return &payment.PaymentNotification{
- TradeNo: notification.TradeNo,
- OrderID: notification.OutTradeNo,
- Amount: amount,
- Status: status,
- RawData: rawBody,
+ TradeNo: notification.TradeNo,
+ OrderID: notification.OutTradeNo,
+ Amount: amount,
+ Status: status,
+ RawData: rawBody,
+ Metadata: metadata,
}, nil
}
@@ -272,8 +368,23 @@ func isTradeNotExist(err error) bool {
return strings.Contains(err.Error(), alipayErrTradeNotExist)
}
+func parseAlipayAmount(values ...string) (float64, error) {
+ for _, raw := range values {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ continue
+ }
+ amount, err := strconv.ParseFloat(raw, 64)
+ if err == nil {
+ return amount, nil
+ }
+ }
+ return 0, fmt.Errorf("no valid amount field")
+}
+
// Ensure interface compliance.
var (
- _ payment.Provider = (*Alipay)(nil)
- _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.Provider = (*Alipay)(nil)
+ _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.MerchantIdentityProvider = (*Alipay)(nil)
)
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index 7b0ce0d8..fdc8eec1 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -3,9 +3,14 @@
package provider
import (
+ "context"
"errors"
+ "net/url"
"strings"
"testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/smartwalle/alipay/v3"
)
func TestIsTradeNotExist(t *testing.T) {
@@ -130,3 +135,173 @@ func TestNewAlipay(t *testing.T) {
})
}
}
+
+func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ alipayTradeWapPay = origWapPay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ wapPayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ return nil, errors.New("merchant does not have FACE_TO_FACE_PAYMENT")
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ if param.OutTradeNo != "sub2_100" {
+ t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100")
+ }
+ if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" {
+ t.Fatalf("notify_url = %q", param.NotifyURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "88.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 1 {
+ t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
+ }
+ if pagePayCalls != 1 {
+ t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
+ }
+ if wapPayCalls != 0 {
+ t.Fatalf("wap pay calls = %d, want 0", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for desktop page pay")
+ }
+ if resp.QRCode != resp.PayURL {
+ t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
+ }
+}
+
+func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradeWapPay = origWapPay
+ })
+
+ wapPayCalls := 0
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ if param.ReturnURL != "https://merchant.example.com/payment/result" {
+ t.Fatalf("return_url = %q", param.ReturnURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_101",
+ Amount: "18.00",
+ Subject: "Balance recharge",
+ IsMobile: true,
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if wapPayCalls != 1 {
+ t.Fatalf("wap pay calls = %d, want 1", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for mobile wap pay")
+ }
+}
+
+func TestCreateTradeUsesPrecreateForDesktopWhenAvailable(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ if param.ProductCode != alipayProductCodePreCreate {
+ t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePreCreate)
+ }
+ return &alipay.TradePreCreateRsp{
+ Error: alipay.Error{Code: alipay.CodeSuccess},
+ QRCode: "https://qr.alipay.example.com/precreate-token",
+ }, nil
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_102",
+ Amount: "66.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 1 {
+ t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
+ }
+ if pagePayCalls != 0 {
+ t.Fatalf("page pay calls = %d, want 0", pagePayCalls)
+ }
+ if resp.QRCode != "https://qr.alipay.example.com/precreate-token" {
+ t.Fatalf("qr_code = %q", resp.QRCode)
+ }
+ if resp.PayURL != "" {
+ t.Fatalf("pay_url = %q, want empty for precreate", resp.PayURL)
+ }
+}
+
+func TestAlipayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &Alipay{
+ config: map[string]string{
+ "appId": "2021001234567890",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["app_id"] != "2021001234567890" {
+ t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
+ }
+}
+
+func TestParseAlipayAmount(t *testing.T) {
+ t.Parallel()
+
+ amount, err := parseAlipayAmount("", "88.00", "77.00")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if amount != 88 {
+ t.Fatalf("amount = %v, want 88", amount)
+ }
+
+ if _, err := parseAlipayAmount("", "not-a-number"); err == nil {
+ t.Fatal("expected error when no valid amount field exists")
+ }
+}
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
index e33a567d..e7d8aab9 100644
--- a/backend/internal/payment/provider/easypay.go
+++ b/backend/internal/payment/provider/easypay.go
@@ -25,6 +25,7 @@ const (
easypayStatusPaid = 1
easypayHTTPTimeout = 10 * time.Second
maxEasypayResponseSize = 1 << 20 // 1MB
+ maxEasypayErrorSummary = 512
tradeStatusSuccess = "TRADE_SUCCESS"
signTypeMD5 = "MD5"
paymentModePopup = "popup"
@@ -42,23 +43,72 @@ type EasyPay struct {
// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
- if config[k] == "" {
+ if strings.TrimSpace(config[k]) == "" {
return nil, fmt.Errorf("easypay config missing required key: %s", k)
}
}
+ cfg := make(map[string]string, len(config))
+ for k, v := range config {
+ cfg[k] = v
+ }
+ cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
return &EasyPay{
instanceID: instanceID,
- config: config,
+ config: cfg,
httpClient: &http.Client{Timeout: easypayHTTPTimeout},
}, nil
}
+func normalizeEasyPayAPIBase(apiBase string) string {
+ base := strings.TrimSpace(apiBase)
+ if base == "" {
+ return ""
+ }
+ if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+ parsed.RawQuery = ""
+ parsed.Fragment = ""
+ parsed.RawPath = ""
+ parsed.Path = trimEasyPayEndpointPath(parsed.Path)
+ return strings.TrimRight(parsed.String(), "/")
+ }
+ return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
+}
+
+func trimEasyPayEndpointPath(path string) string {
+ path = strings.TrimRight(strings.TrimSpace(path), "/")
+ lower := strings.ToLower(path)
+ for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
+ if strings.HasSuffix(lower, endpoint) {
+ return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
+ }
+ }
+ return path
+}
+
+func (e *EasyPay) apiBase() string {
+ if e == nil {
+ return ""
+ }
+ return normalizeEasyPayAPIBase(e.config["apiBase"])
+}
+
func (e *EasyPay) Name() string { return "EasyPay" }
func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
func (e *EasyPay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
}
+func (e *EasyPay) MerchantIdentityMetadata() map[string]string {
+ if e == nil {
+ return nil
+ }
+ pid := strings.TrimSpace(e.config["pid"])
+ if pid == "" {
+ return nil
+ }
+ return map[string]string{"pid": pid}
+}
+
func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
// Payment mode determined by instance config, not payment type.
// "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
@@ -93,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym
for k, v := range params {
q.Set(k, v)
}
- base := strings.TrimRight(e.config["apiBase"], "/")
- payURL := base + "/submit.php?" + q.Encode()
+ payURL := e.apiBase() + "/submit.php?" + q.Encode()
return &payment.CreatePaymentResponse{PayURL: payURL}, nil
}
@@ -116,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen
params["sign"] = easyPaySign(params, e.config["pkey"])
params["sign_type"] = signTypeMD5
- body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params)
+ body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
if err != nil {
return nil, fmt.Errorf("easypay create: %w", err)
}
@@ -160,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
"act": "order", "pid": e.config["pid"],
"key": e.config["pkey"], "out_trade_no": tradeNo,
}
- body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params)
+ body, err := e.post(ctx, e.apiBase()+"/api.php", params)
if err != nil {
return nil, fmt.Errorf("easypay query: %w", err)
}
@@ -178,7 +227,12 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer
status = payment.ProviderStatusPaid
}
amount, _ := strconv.ParseFloat(resp.Money, 64)
- return &payment.QueryOrderResponse{TradeNo: tradeNo, Status: status, Amount: amount}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: tradeNo,
+ Status: status,
+ Amount: amount,
+ Metadata: e.MerchantIdentityMetadata(),
+ }, nil
}
func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
@@ -203,32 +257,143 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st
status = payment.ProviderStatusSuccess
}
amount, _ := strconv.ParseFloat(params["money"], 64)
+
+ metadata := e.MerchantIdentityMetadata()
+ if pid := strings.TrimSpace(params["pid"]); pid != "" {
+ if metadata == nil {
+ metadata = map[string]string{}
+ }
+ metadata["pid"] = pid
+ }
return &payment.PaymentNotification{
TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
- Amount: amount, Status: status, RawData: rawBody,
+ Amount: amount, Status: status, RawData: rawBody, Metadata: metadata,
}, nil
}
func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
- params := map[string]string{
- "pid": e.config["pid"], "key": e.config["pkey"],
- "trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount,
+ attempts := e.refundAttempts(req)
+ if len(attempts) == 0 {
+ return nil, fmt.Errorf("easypay refund missing order identifier")
}
- body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params)
- if err != nil {
- return nil, fmt.Errorf("easypay refund: %w", err)
+ var firstErr error
+ for i, attempt := range attempts {
+ body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
+ if err != nil {
+ return nil, fmt.Errorf("easypay refund request: %w", err)
+ }
+ if err := parseEasyPayRefundResponse(status, body); err != nil {
+ if firstErr == nil {
+ firstErr = err
+ }
+ if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
+ continue
+ }
+ return nil, err
+ }
+ return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
}
+ return nil, firstErr
+}
+
+type easyPayRefundAttempt struct {
+ params map[string]string
+ refundID string
+}
+
+func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
+ base := map[string]string{
+ "pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
+ }
+ var attempts []easyPayRefundAttempt
+ if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
+ params := cloneStringMap(base)
+ params["out_trade_no"] = orderID
+ attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
+ }
+ if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
+ params := cloneStringMap(base)
+ params["trade_no"] = tradeNo
+ attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
+ }
+ return attempts
+}
+
+func cloneStringMap(in map[string]string) map[string]string {
+ out := make(map[string]string, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func isEasyPayRefundOrderNotFound(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := err.Error()
+ lower := strings.ToLower(msg)
+ return strings.Contains(msg, "订单编号不存在") ||
+ strings.Contains(msg, "订单不存在") ||
+ strings.Contains(lower, "order not found") ||
+ strings.Contains(lower, "not exist")
+}
+
+func parseEasyPayRefundResponse(status int, body []byte) error {
+ summary := summarizeEasyPayResponse(body)
+ if status < http.StatusOK || status >= http.StatusMultipleChoices {
+ return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
+ }
+
+ trimmed := strings.TrimSpace(string(body))
+ if trimmed == "" {
+ return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
+ }
+
+ lower := strings.ToLower(trimmed)
+ if strings.HasPrefix(lower, ""
+ }
+ if len(summary) > maxEasypayErrorSummary {
+ return summary[:maxEasypayErrorSummary] + "..."
+ }
+ return summary
}
func (e *EasyPay) resolveCID(paymentType string) string {
@@ -245,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string {
}
func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
+ body, _, err := e.postRaw(ctx, endpoint, params)
+ return body, err
+}
+
+func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
- return nil, err
+ return nil, 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- resp, err := e.httpClient.Do(req)
+ client := e.httpClient
+ if client == nil {
+ client = &http.Client{Timeout: easypayHTTPTimeout}
+ }
+ resp, err := client.Do(req)
if err != nil {
- return nil, err
+ return nil, 0, err
}
defer func() { _ = resp.Body.Close() }()
- return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
+ if err != nil {
+ return nil, resp.StatusCode, err
+ }
+ return body, resp.StatusCode, nil
}
func easyPaySign(params map[string]string, pkey string) string {
diff --git a/backend/internal/payment/provider/easypay_refund_test.go b/backend/internal/payment/provider/easypay_refund_test.go
new file mode 100644
index 00000000..9e0e4942
--- /dev/null
+++ b/backend/internal/payment/provider/easypay_refund_test.go
@@ -0,0 +1,196 @@
+package provider
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+func TestNormalizeEasyPayAPIBase(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ input string
+ want string
+ }{
+ {input: "https://zpayz.cn", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ t.Parallel()
+ if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
+ t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
+ t.Parallel()
+
+ var gotPath string
+ var gotQuery url.Values
+ var gotForm url.Values
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotQuery = r.URL.Query()
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("ParseForm: %v", err)
+ }
+ gotForm = r.PostForm
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL+"/mapi.php")
+ resp, err := provider.Refund(context.Background(), payment.RefundRequest{
+ TradeNo: "trade-123",
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err != nil {
+ t.Fatalf("Refund returned error: %v", err)
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusSuccess {
+ t.Fatalf("Refund response = %+v, want success", resp)
+ }
+ if gotPath != "/api.php" {
+ t.Fatalf("refund path = %q, want /api.php", gotPath)
+ }
+ if gotQuery.Get("act") != "refund" {
+ t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
+ }
+ for key, want := range map[string]string{
+ "pid": "pid-1",
+ "key": "pkey-1",
+ "out_trade_no": "out-456",
+ "money": "1.50",
+ } {
+ if got := gotForm.Get(key); got != want {
+ t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
+ }
+ }
+ if got := gotForm.Get("trade_no"); got != "" {
+ t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
+ }
+}
+
+func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
+ t.Parallel()
+
+ var gotForms []url.Values
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api.php" {
+ t.Errorf("refund path = %q, want /api.php", r.URL.Path)
+ }
+ if r.URL.Query().Get("act") != "refund" {
+ t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
+ }
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("ParseForm: %v", err)
+ }
+ gotForms = append(gotForms, r.PostForm)
+ w.Header().Set("Content-Type", "application/json")
+ if len(gotForms) == 1 {
+ _, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
+ return
+ }
+ _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL+"/mapi.php")
+ resp, err := provider.Refund(context.Background(), payment.RefundRequest{
+ TradeNo: "trade-123",
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err != nil {
+ t.Fatalf("Refund returned error: %v", err)
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
+ t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
+ }
+ if len(gotForms) != 2 {
+ t.Fatalf("refund attempts = %d, want 2", len(gotForms))
+ }
+ if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
+ t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
+ }
+ if got := gotForms[0].Get("trade_no"); got != "" {
+ t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
+ }
+ if got := gotForms[1].Get("trade_no"); got != "trade-123" {
+ t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
+ }
+ if got := gotForms[1].Get("out_trade_no"); got != "" {
+ t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
+ }
+}
+
+func TestEasyPayRefundResponseErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ statusCode int
+ body string
+ want string
+ }{
+ {name: "html response", statusCode: http.StatusOK, body: "bad config", want: "non-JSON response (HTTP 200): bad config"},
+ {name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
+ {name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
+ {name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): "},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(tt.statusCode)
+ _, _ = w.Write([]byte(tt.body))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL)
+ _, err := provider.Refund(context.Background(), payment.RefundRequest{
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err == nil {
+ t.Fatal("Refund returned nil error")
+ }
+ if !strings.Contains(err.Error(), tt.want) {
+ t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
+ }
+ })
+ }
+}
+
+func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
+ t.Helper()
+
+ provider, err := NewEasyPay("test-instance", map[string]string{
+ "pid": "pid-1",
+ "pkey": "pkey-1",
+ "apiBase": apiBase,
+ "notifyUrl": "https://example.com/notify",
+ "returnUrl": "https://example.com/return",
+ })
+ if err != nil {
+ t.Fatalf("NewEasyPay: %v", err)
+ }
+ return provider
+}
diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go
index 146a6fa1..8328d294 100644
--- a/backend/internal/payment/provider/easypay_sign_test.go
+++ b/backend/internal/payment/provider/easypay_sign_test.go
@@ -178,3 +178,18 @@ func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
}
}
+
+func TestEasyPayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &EasyPay{
+ config: map[string]string{
+ "pid": "1001",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["pid"] != "1001" {
+ t.Fatalf("pid = %q, want %q", metadata["pid"], "1001")
+ }
+}
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
index 0b41c4fb..e6291dd3 100644
--- a/backend/internal/payment/provider/wxpay.go
+++ b/backend/internal/payment/provider/wxpay.go
@@ -3,22 +3,24 @@ package provider
import (
"bytes"
"context"
- "crypto/rsa"
"fmt"
"io"
- "log/slog"
"net/http"
+ "net/url"
+ "strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
"github.com/wechatpay-apiv3/wechatpay-go/core/notify"
"github.com/wechatpay-apiv3/wechatpay-go/core/option"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
"github.com/wechatpay-apiv3/wechatpay-go/utils"
@@ -26,8 +28,23 @@ import (
// WeChat Pay constants.
const (
- wxpayCurrency = "CNY"
- wxpayH5Type = "Wap"
+ wxpayCurrency = "CNY"
+ wxpayH5Type = "Wap"
+ wxpayResultPath = "/payment/result"
+)
+
+const (
+ wxpayMetadataAppID = "appid"
+ wxpayMetadataMerchantID = "mchid"
+ wxpayMetadataCurrency = "currency"
+ wxpayMetadataTradeState = "trade_state"
+)
+
+// WeChat Pay create-payment modes.
+const (
+ wxpayModeNative = "native"
+ wxpayModeH5 = "h5"
+ wxpayModeJSAPI = "jsapi"
)
// WeChat Pay trade states.
@@ -43,9 +60,16 @@ const (
wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
)
-// WeChat Pay error codes.
-const (
- wxpayErrNoAuth = "NO_AUTH"
+var (
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ return svc.Prepay(ctx, req)
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ return svc.Prepay(ctx, req)
+ }
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ return svc.PrepayWithRequestPayment(ctx, req)
+ }
)
type Wxpay struct {
@@ -56,15 +80,35 @@ type Wxpay struct {
notifyHandler *notify.Handler
}
+const wxpayAPIv3KeyLength = 32
+
func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
- required := []string{"appId", "mchId", "privateKey", "apiV3Key", "publicKey", "publicKeyId", "certSerial"}
+ // All fields are required. Platform-certificate mode is intentionally unsupported —
+ // WeChat has been migrating all merchants to the pubkey verifier since 2024-10,
+ // and newly-provisioned merchants cannot download platform certificates at all.
+ required := []string{"appId", "mchId", "privateKey", "apiV3Key", "certSerial", "publicKey", "publicKeyId"}
for _, k := range required {
if config[k] == "" {
- return nil, fmt.Errorf("wxpay config missing required key: %s", k)
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_MISSING_KEY", "missing_required_key").
+ WithMetadata(map[string]string{"key": k})
}
}
- if len(config["apiV3Key"]) != 32 {
- return nil, fmt.Errorf("wxpay apiV3Key must be exactly 32 bytes, got %d", len(config["apiV3Key"]))
+ if len(config["apiV3Key"]) != wxpayAPIv3KeyLength {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY_LENGTH", "invalid_key_length").
+ WithMetadata(map[string]string{
+ "key": "apiV3Key",
+ "expected": strconv.Itoa(wxpayAPIv3KeyLength),
+ "actual": strconv.Itoa(len(config["apiV3Key"])),
+ })
+ }
+ // Parse PEMs eagerly so malformed keys surface at save time, not at order creation.
+ if _, err := utils.LoadPrivateKey(formatPEM(config["privateKey"], "PRIVATE KEY")); err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "privateKey"})
+ }
+ if _, err := utils.LoadPublicKey(formatPEM(config["publicKey"], "PUBLIC KEY")); err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "publicKey"})
}
return &Wxpay{instanceID: instanceID, config: config}, nil
}
@@ -75,6 +119,16 @@ func (w *Wxpay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeWxpay}
}
+// ResolveWxpayJSAPIAppID returns the AppID that JSAPI prepay will use for a
+// given provider config. A dedicated MP AppID takes precedence over the base
+// merchant AppID.
+func ResolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
func formatPEM(key, keyType string) string {
key = strings.TrimSpace(key)
if strings.HasPrefix(key, "-----BEGIN") {
@@ -89,14 +143,19 @@ func (w *Wxpay) ensureClient() (*core.Client, error) {
if w.coreClient != nil {
return w.coreClient, nil
}
- privateKey, publicKey, err := w.loadKeyPair()
+ privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
if err != nil {
- return nil, err
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "privateKey"})
+ }
+ publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
+ if err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "publicKey"})
}
- certSerial := w.config["certSerial"]
verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey)
client, err := core.NewClient(context.Background(),
- option.WithMerchantCredential(w.config["mchId"], certSerial, privateKey),
+ option.WithMerchantCredential(w.config["mchId"], w.config["certSerial"], privateKey),
option.WithVerifier(verifier))
if err != nil {
return nil, fmt.Errorf("wxpay init client: %w", err)
@@ -110,18 +169,6 @@ func (w *Wxpay) ensureClient() (*core.Client, error) {
return w.coreClient, nil
}
-func (w *Wxpay) loadKeyPair() (*rsa.PrivateKey, *rsa.PublicKey, error) {
- privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
- if err != nil {
- return nil, nil, fmt.Errorf("wxpay load private key: %w", err)
- }
- publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
- if err != nil {
- return nil, nil, fmt.Errorf("wxpay load public key: %w", err)
- }
- return privateKey, publicKey, nil
-}
-
func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := w.ensureClient()
if err != nil {
@@ -139,30 +186,61 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
if err != nil {
return nil, fmt.Errorf("wxpay create payment: %w", err)
}
- if req.IsMobile && req.ClientIP != "" {
- resp, err := w.createOrder(ctx, client, req, notifyURL, totalFen, true)
- if err == nil {
- return resp, nil
- }
- if !strings.Contains(err.Error(), wxpayErrNoAuth) {
- return nil, err
- }
- slog.Warn("wxpay H5 payment not authorized, falling back to native", "order", req.OrderID)
+
+ mode, err := resolveWxpayCreateMode(req)
+ if err != nil {
+ return nil, err
+ }
+ switch mode {
+ case wxpayModeJSAPI:
+ return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen)
+ case wxpayModeH5:
+ return w.prepayH5(ctx, client, req, notifyURL, totalFen)
+ case wxpayModeNative:
+ return w.prepayNative(ctx, client, req, notifyURL, totalFen)
+ default:
+ return nil, fmt.Errorf("wxpay create payment: unsupported mode %q", mode)
}
- return w.createOrder(ctx, client, req, notifyURL, totalFen, false)
}
-func (w *Wxpay) createOrder(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64, useH5 bool) (*payment.CreatePaymentResponse, error) {
- if useH5 {
- return w.prepayH5(ctx, c, req, notifyURL, totalFen)
+func (w *Wxpay) prepayJSAPI(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
+ svc := jsapi.JsapiApiService{Client: c}
+ cur := wxpayCurrency
+ appID := ResolveWxpayJSAPIAppID(w.config)
+ prepayReq := jsapi.PrepayRequest{
+ Appid: core.String(appID),
+ Mchid: core.String(w.config["mchId"]),
+ Description: core.String(req.Subject),
+ OutTradeNo: core.String(req.OrderID),
+ NotifyUrl: core.String(notifyURL),
+ Amount: &jsapi.Amount{Total: core.Int64(totalFen), Currency: &cur},
+ Payer: &jsapi.Payer{Openid: core.String(strings.TrimSpace(req.OpenID))},
}
- return w.prepayNative(ctx, c, req, notifyURL, totalFen)
+ if clientIP := strings.TrimSpace(req.ClientIP); clientIP != "" {
+ prepayReq.SceneInfo = &jsapi.SceneInfo{PayerClientIp: core.String(clientIP)}
+ }
+ resp, _, err := wxpayJSAPIPrepayWithRequestPayment(ctx, svc, prepayReq)
+ if err != nil {
+ return nil, fmt.Errorf("wxpay jsapi prepay: %w", err)
+ }
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ ResultType: payment.CreatePaymentResultJSAPIReady,
+ JSAPI: &payment.WechatJSAPIPayload{
+ AppID: wxSV(resp.Appid),
+ TimeStamp: wxSV(resp.TimeStamp),
+ NonceStr: wxSV(resp.NonceStr),
+ Package: wxSV(resp.Package),
+ SignType: wxSV(resp.SignType),
+ PaySign: wxSV(resp.PaySign),
+ },
+ }, nil
}
func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := native.NativeApiService{Client: c}
cur := wxpayCurrency
- resp, _, err := svc.Prepay(ctx, native.PrepayRequest{
+ resp, _, err := wxpayNativePrepay(ctx, svc, native.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
@@ -181,13 +259,12 @@ func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.Cr
func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
svc := h5.H5ApiService{Client: c}
cur := wxpayCurrency
- tp := wxpayH5Type
- resp, _, err := svc.Prepay(ctx, h5.PrepayRequest{
+ resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{
Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
NotifyUrl: core.String(notifyURL),
Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
- SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: &h5.H5Info{Type: &tp}},
+ SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)},
})
if err != nil {
return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
@@ -196,9 +273,77 @@ func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.Create
if resp.H5Url != nil {
h5URL = *resp.H5Url
}
+ h5URL, err = appendWxpayRedirectURL(h5URL, req)
+ if err != nil {
+ return nil, err
+ }
return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
}
+func buildWxpayH5Info(config map[string]string) *h5.H5Info {
+ tp := wxpayH5Type
+ info := &h5.H5Info{Type: &tp}
+ if appName := strings.TrimSpace(config["h5AppName"]); appName != "" {
+ info.AppName = core.String(appName)
+ }
+ if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" {
+ info.AppUrl = core.String(appURL)
+ }
+ return info
+}
+
+func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) {
+ if strings.TrimSpace(req.OpenID) != "" {
+ return wxpayModeJSAPI, nil
+ }
+ if req.IsMobile {
+ if strings.TrimSpace(req.ClientIP) == "" {
+ return "", fmt.Errorf("wxpay H5 payment requires client IP")
+ }
+ return wxpayModeH5, nil
+ }
+ return wxpayModeNative, nil
+}
+
+func appendWxpayRedirectURL(h5URL string, req payment.CreatePaymentRequest) (string, error) {
+ h5URL = strings.TrimSpace(h5URL)
+ returnURL := strings.TrimSpace(req.ReturnURL)
+ if h5URL == "" || returnURL == "" {
+ return h5URL, nil
+ }
+
+ redirectURL, err := buildWxpayResultURL(returnURL, req)
+ if err != nil {
+ return "", err
+ }
+
+ sep := "&"
+ if !strings.Contains(h5URL, "?") {
+ sep = "?"
+ }
+ return h5URL + sep + "redirect_url=" + url.QueryEscape(redirectURL), nil
+}
+
+func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (string, error) {
+ u, err := url.Parse(returnURL)
+ if err != nil || !u.IsAbs() || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
+ return "", fmt.Errorf("return URL must be an absolute http(s) URL")
+ }
+
+ values := u.Query()
+ values.Set("out_trade_no", strings.TrimSpace(req.OrderID))
+ if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" {
+ values.Set("payment_type", paymentType)
+ }
+ if strings.TrimSpace(u.Path) == "" {
+ u.Path = wxpayResultPath
+ }
+ u.RawPath = ""
+ u.RawQuery = values.Encode()
+ u.Fragment = ""
+ return u.String(), nil
+}
+
func wxSV(s *string) string {
if s == nil {
return ""
@@ -219,6 +364,32 @@ func mapWxState(s string) string {
}
}
+func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
+ if tx == nil {
+ return nil
+ }
+
+ metadata := map[string]string{}
+ if appID := wxSV(tx.Appid); appID != "" {
+ metadata[wxpayMetadataAppID] = appID
+ }
+ if merchantID := wxSV(tx.Mchid); merchantID != "" {
+ metadata[wxpayMetadataMerchantID] = merchantID
+ }
+ if tradeState := wxSV(tx.TradeState); tradeState != "" {
+ metadata[wxpayMetadataTradeState] = tradeState
+ }
+ if tx.Amount != nil {
+ if currency := wxSV(tx.Amount.Currency); currency != "" {
+ metadata[wxpayMetadataCurrency] = currency
+ }
+ }
+ if len(metadata) == 0 {
+ return nil
+ }
+ return metadata
+}
+
func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
c, err := w.ensureClient()
if err != nil {
@@ -243,7 +414,13 @@ func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryO
if tx.SuccessTime != nil {
pa = *tx.SuccessTime
}
- return &payment.QueryOrderResponse{TradeNo: id, Status: mapWxState(wxSV(tx.TradeState)), Amount: amt, PaidAt: pa}, nil
+ return &payment.QueryOrderResponse{
+ TradeNo: id,
+ Status: mapWxState(wxSV(tx.TradeState)),
+ Amount: amt,
+ PaidAt: pa,
+ Metadata: buildWxpayTransactionMetadata(tx),
+ }, nil
}
func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
@@ -275,7 +452,7 @@ func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers
}
return &payment.PaymentNotification{
TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
- Amount: amt, Status: st, RawData: rawBody,
+ Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
}, nil
}
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
index b8b99537..e8ac5e54 100644
--- a/backend/internal/payment/provider/wxpay_test.go
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -3,12 +3,44 @@
package provider
import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/wechatpay-apiv3/wechatpay-go/core"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
)
+// generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings.
+// The wechatpay-go SDK expects PKCS8 private keys and PKIX public keys.
+func generateTestKeyPair(t *testing.T) (privPEM, pubPEM string) {
+ t.Helper()
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatalf("generate rsa key: %v", err)
+ }
+ privDER, err := x509.MarshalPKCS8PrivateKey(key)
+ if err != nil {
+ t.Fatalf("marshal pkcs8: %v", err)
+ }
+ pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ if err != nil {
+ t.Fatalf("marshal pkix: %v", err)
+ }
+ return string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+ string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}))
+}
+
func TestMapWxState(t *testing.T) {
t.Parallel()
@@ -96,6 +128,33 @@ func TestWxSV(t *testing.T) {
}
}
+func TestBuildWxpayTransactionMetadata(t *testing.T) {
+ t.Parallel()
+
+ tx := &payments.Transaction{
+ Appid: strPtr("wx-app-id"),
+ Mchid: strPtr("mch-id"),
+ TradeState: strPtr(wxpayTradeStateSuccess),
+ Amount: &payments.TransactionAmount{
+ Currency: strPtr(wxpayCurrency),
+ },
+ }
+
+ metadata := buildWxpayTransactionMetadata(tx)
+ if metadata[wxpayMetadataAppID] != "wx-app-id" {
+ t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
+ }
+ if metadata[wxpayMetadataMerchantID] != "mch-id" {
+ t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
+ }
+ if metadata[wxpayMetadataCurrency] != wxpayCurrency {
+ t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
+ }
+ if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
+ t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
+ }
+}
+
func strPtr(s string) *string {
return &s
}
@@ -149,13 +208,14 @@ func TestFormatPEM(t *testing.T) {
func TestNewWxpay(t *testing.T) {
t.Parallel()
+ privPEM, pubPEM := generateTestKeyPair(t)
validConfig := map[string]string{
"appId": "wx1234567890",
"mchId": "1234567890",
- "privateKey": "fake-private-key",
+ "privateKey": privPEM,
"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
- "publicKey": "fake-public-key",
- "publicKeyId": "key-id-001",
+ "publicKey": pubPEM,
+ "publicKeyId": "PUB_KEY_ID_TEST",
"certSerial": "SERIAL001",
}
@@ -206,6 +266,12 @@ func TestNewWxpay(t *testing.T) {
wantErr: true,
errSubstr: "apiV3Key",
},
+ {
+ name: "missing certSerial",
+ config: withOverride(map[string]string{"certSerial": ""}),
+ wantErr: true,
+ errSubstr: "certSerial",
+ },
{
name: "missing publicKey",
config: withOverride(map[string]string{"publicKey": ""}),
@@ -218,17 +284,29 @@ func TestNewWxpay(t *testing.T) {
wantErr: true,
errSubstr: "publicKeyId",
},
+ {
+ name: "malformed privateKey PEM",
+ config: withOverride(map[string]string{"privateKey": "not-a-valid-pem"}),
+ wantErr: true,
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY",
+ },
+ {
+ name: "malformed publicKey PEM",
+ config: withOverride(map[string]string{"publicKey": "not-a-valid-pem"}),
+ wantErr: true,
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY",
+ },
{
name: "apiV3Key too short",
config: withOverride(map[string]string{"apiV3Key": "short"}),
wantErr: true,
- errSubstr: "exactly 32 bytes",
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
},
{
name: "apiV3Key too long",
config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes
wantErr: true,
- errSubstr: "exactly 32 bytes",
+ errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
},
}
@@ -257,3 +335,375 @@ func TestNewWxpay(t *testing.T) {
})
}
}
+
+func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) {
+ t.Parallel()
+
+ resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{
+ OrderID: "sub2_42",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("buildWxpayResultURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(resultURL)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ query := parsed.Query()
+ if parsed.Path != wxpayResultPath {
+ t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath)
+ }
+ if query.Get("resume_token") != "resume-42" {
+ t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42")
+ }
+ if query.Get("order_id") != "42" {
+ t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42")
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42")
+ }
+}
+
+func TestResolveWxpayJSAPIAppID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ config map[string]string
+ want string
+ }{
+ {
+ name: "prefers dedicated mp app id",
+ config: map[string]string{
+ "mpAppId": "wx-mp-app",
+ "appId": "wx-merchant-app",
+ },
+ want: "wx-mp-app",
+ },
+ {
+ name: "falls back to merchant app id",
+ config: map[string]string{
+ "appId": "wx-merchant-app",
+ },
+ want: "wx-merchant-app",
+ },
+ {
+ name: "missing app ids returns empty",
+ config: map[string]string{},
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want {
+ t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestResolveWxpayCreateMode(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ req payment.CreatePaymentRequest
+ wantMode string
+ wantErr string
+ }{
+ {
+ name: "desktop uses native",
+ req: payment.CreatePaymentRequest{},
+ wantMode: wxpayModeNative,
+ },
+ {
+ name: "mobile uses h5 when client ip is present",
+ req: payment.CreatePaymentRequest{
+ IsMobile: true,
+ ClientIP: "203.0.113.10",
+ },
+ wantMode: wxpayModeH5,
+ },
+ {
+ name: "mobile without client ip returns clear error",
+ req: payment.CreatePaymentRequest{
+ IsMobile: true,
+ },
+ wantErr: "requires client IP",
+ },
+ {
+ name: "openid uses jsapi mode",
+ req: payment.CreatePaymentRequest{
+ OpenID: "openid-123",
+ },
+ wantMode: wxpayModeJSAPI,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, err := resolveWxpayCreateMode(tt.req)
+ if tt.wantErr != "" {
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ if !strings.Contains(err.Error(), tt.wantErr) {
+ t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != tt.wantMode {
+ t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode)
+ }
+ })
+ }
+}
+
+func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ if got := wxSV(req.Payer.Openid); got != "openid-123" {
+ t.Fatalf("openid = %q, want %q", got, "openid-123")
+ }
+ if req.SceneInfo == nil || wxSV(req.SceneInfo.PayerClientIp) != "203.0.113.10" {
+ t.Fatalf("scene_info payer_client_ip = %q, want %q", wxSV(req.SceneInfo.PayerClientIp), "203.0.113.10")
+ }
+ return &jsapi.PrepayWithRequestPaymentResponse{
+ Appid: core.String("wx123"),
+ TimeStamp: core.String("1712345678"),
+ NonceStr: core.String("nonce-123"),
+ Package: core.String("prepay_id=wx_prepay_123"),
+ SignType: core.String("RSA"),
+ PaySign: core.String("signed-payload"),
+ }, nil, nil
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ return &h5.PrepayResponse{}, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_88",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ NotifyURL: "https://merchant.example/payment/notify",
+ OpenID: "openid-123",
+ ClientIP: "203.0.113.10",
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if jsapiCalls != 1 {
+ t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if h5Calls != 0 {
+ t.Fatalf("h5 prepay calls = %d, want 0", h5Calls)
+ }
+ if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
+ }
+ if resp.JSAPI == nil {
+ t.Fatal("expected jsapi payload, got nil")
+ }
+ if resp.JSAPI.AppID != "wx123" {
+ t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123")
+ }
+ if resp.JSAPI.TimeStamp != "1712345678" {
+ t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678")
+ }
+ if resp.JSAPI.NonceStr != "nonce-123" {
+ t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123")
+ }
+ if resp.JSAPI.Package != "prepay_id=wx_prepay_123" {
+ t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123")
+ }
+ if resp.JSAPI.SignType != "RSA" {
+ t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA")
+ }
+ if resp.JSAPI.PaySign != "signed-payload" {
+ t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload")
+ }
+}
+
+func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ if req.SceneInfo == nil {
+ t.Fatal("expected scene_info, got nil")
+ }
+ if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" {
+ t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10")
+ }
+ if req.SceneInfo.H5Info == nil {
+ t.Fatal("expected scene_info.h5_info, got nil")
+ }
+ if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type {
+ t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type)
+ }
+ if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" {
+ t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API")
+ }
+ if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" {
+ t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com")
+ }
+ return &h5.PrepayResponse{
+ H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"),
+ }, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ "h5AppName": "Sub2API",
+ "h5AppUrl": "https://app.example.com",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_99",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ Subject: "Balance Recharge",
+ NotifyURL: "https://merchant.example/payment/notify",
+ ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99",
+ ClientIP: "203.0.113.10",
+ IsMobile: true,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if jsapiCalls != 0 {
+ t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if h5Calls != 1 {
+ t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+ }
+ if !strings.Contains(resp.PayURL, "redirect_url=") {
+ t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
+ }
+}
+
+func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) {
+ origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+ origNativePrepay := wxpayNativePrepay
+ origH5Prepay := wxpayH5Prepay
+ t.Cleanup(func() {
+ wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+ wxpayNativePrepay = origNativePrepay
+ wxpayH5Prepay = origH5Prepay
+ })
+
+ jsapiCalls := 0
+ nativeCalls := 0
+ h5Calls := 0
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ jsapiCalls++
+ return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ h5Calls++
+ return nil, nil, errors.New("NO_AUTH")
+ }
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ nativeCalls++
+ return &native.PrepayResponse{
+ CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"),
+ }, nil, nil
+ }
+
+ provider := &Wxpay{
+ config: map[string]string{
+ "appId": "wx123",
+ "mchId": "mch123",
+ },
+ coreClient: &core.Client{},
+ }
+
+ resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "66.88",
+ PaymentType: payment.TypeWxpay,
+ Subject: "Balance Recharge",
+ NotifyURL: "https://merchant.example/payment/notify",
+ ClientIP: "203.0.113.10",
+ IsMobile: true,
+ })
+ if err == nil {
+ t.Fatal("expected no-auth error, got nil")
+ }
+ if jsapiCalls != 0 {
+ t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+ }
+ if h5Calls != 1 {
+ t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+ }
+ if nativeCalls != 0 {
+ t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+ }
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+ if !strings.Contains(err.Error(), "NO_AUTH") {
+ t.Fatalf("error = %v, want NO_AUTH", err)
+ }
+}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
index 5d613a4a..e7ac6727 100644
--- a/backend/internal/payment/types.go
+++ b/backend/internal/payment/types.go
@@ -101,34 +101,69 @@ type CreatePaymentRequest struct {
Subject string // Product description
NotifyURL string // Webhook callback URL
ReturnURL string // Browser redirect URL after payment
+ OpenID string // WeChat JSAPI payer OpenID when available
ClientIP string // Payer's IP address
IsMobile bool // Whether the request comes from a mobile device
InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
}
+// CreatePaymentResultType describes the shape of the create-payment result.
+type CreatePaymentResultType = string
+
+const (
+ CreatePaymentResultOrderCreated CreatePaymentResultType = "order_created"
+ CreatePaymentResultOAuthRequired CreatePaymentResultType = "oauth_required"
+ CreatePaymentResultJSAPIReady CreatePaymentResultType = "jsapi_ready"
+)
+
+// WechatOAuthInfo describes the next step when WeChat OAuth is required before payment.
+type WechatOAuthInfo struct {
+ AuthorizeURL string `json:"authorize_url,omitempty"`
+ AppID string `json:"appid,omitempty"`
+ OpenID string `json:"openid,omitempty"`
+ Scope string `json:"scope,omitempty"`
+ State string `json:"state,omitempty"`
+ RedirectURL string `json:"redirect_url,omitempty"`
+}
+
+// WechatJSAPIPayload contains the fields the frontend needs to invoke WeChat JSAPI payment.
+type WechatJSAPIPayload struct {
+ AppID string `json:"appId,omitempty"`
+ TimeStamp string `json:"timeStamp,omitempty"`
+ NonceStr string `json:"nonceStr,omitempty"`
+ Package string `json:"package,omitempty"`
+ SignType string `json:"signType,omitempty"`
+ PaySign string `json:"paySign,omitempty"`
+}
+
// CreatePaymentResponse is returned after successfully initiating a payment.
type CreatePaymentResponse struct {
- TradeNo string // Third-party transaction ID
- PayURL string // H5 payment URL (alipay/wxpay)
- QRCode string // QR code content for scanning
- ClientSecret string // Stripe PaymentIntent client secret
+ TradeNo string // Third-party transaction ID
+ PayURL string // H5 payment URL (alipay/wxpay)
+ QRCode string // QR code content for scanning
+ ClientSecret string // Stripe PaymentIntent client secret
+ ResultType CreatePaymentResultType // Typed result contract for frontend flows
+ OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
+ JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
}
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
- TradeNo string
- Status string // "pending", "paid", "failed", "refunded"
- Amount float64 // Amount in CNY
- PaidAt string // RFC3339 timestamp or empty
+ TradeNo string
+ Status string // "pending", "paid", "failed", "refunded"
+ Amount float64 // Amount in CNY
+ PaidAt string // RFC3339 timestamp or empty
+ Metadata map[string]string
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
- TradeNo string
- OrderID string
- Amount float64
- Status string // "success" or "failed"
- RawData string // Raw notification body for audit
+ TradeNo string
+ OrderID string
+ Amount float64
+ Status string // "success" or "failed"
+ RawData string // Raw notification body for audit
+ Metadata map[string]string
}
// RefundRequest contains the parameters for requesting a refund.
@@ -179,3 +214,9 @@ type CancelableProvider interface {
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
+
+// MerchantIdentityProvider exposes the current non-sensitive merchant identity
+// derived from provider configuration for snapshot consistency checks.
+type MerchantIdentityProvider interface {
+ MerchantIdentityMetadata() map[string]string
+}
diff --git a/backend/internal/payment/wire.go b/backend/internal/payment/wire.go
index 9717465d..4b7f422d 100644
--- a/backend/internal/payment/wire.go
+++ b/backend/internal/payment/wire.go
@@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
+ "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
- if cfg.Totp.EncryptionKey == "" {
+ if cfg == nil {
+ slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable")
+ return nil, nil
+ }
+ keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey)
+ if keyHex == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
- key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
+ // Reject auto-generated TOTP keys for payment signing.
+ // They change across restarts/instances and can silently break resume-token flows.
+ if !cfg.Totp.EncryptionKeyConfigured {
+ slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens")
+ return nil, nil
+ }
+ key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}
diff --git a/backend/internal/payment/wire_test.go b/backend/internal/payment/wire_test.go
new file mode 100644
index 00000000..1b360f89
--- /dev/null
+++ b/backend/internal/payment/wire_test.go
@@ -0,0 +1,62 @@
+package payment
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: strings.Repeat("a", 64),
+ EncryptionKeyConfigured: false,
+ },
+ }
+
+ key, err := ProvideEncryptionKey(cfg)
+ if err != nil {
+ t.Fatalf("ProvideEncryptionKey returned error: %v", err)
+ }
+ if len(key) != 0 {
+ t.Fatalf("encryption key len = %d, want 0", len(key))
+ }
+}
+
+func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ EncryptionKeyConfigured: true,
+ },
+ }
+
+ key, err := ProvideEncryptionKey(cfg)
+ if err != nil {
+ t.Fatalf("ProvideEncryptionKey returned error: %v", err)
+ }
+ if len(key) != 32 {
+ t.Fatalf("encryption key len = %d, want 32", len(key))
+ }
+}
+
+func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: "abcd",
+ EncryptionKeyConfigured: true,
+ },
+ }
+
+ _, err := ProvideEncryptionKey(cfg)
+ if err == nil {
+ t.Fatal("expected error for invalid key length")
+ }
+}
diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go
index 095305c2..e8b25c2b 100644
--- a/backend/internal/pkg/apicompat/anthropic_responses_test.go
+++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go
@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
+func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_cached",
+ Model: "gpt-5.2",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "message",
+ Content: []ResponsesContentPart{
+ {Type: "output_text", Text: "Cached response"},
+ },
+ },
+ },
+ Usage: &ResponsesUsage{
+ InputTokens: 54006,
+ OutputTokens: 123,
+ TotalTokens: 54129,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 50688,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
+ assert.Equal(t, 3318, anth.Usage.InputTokens)
+ assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
+ assert.Equal(t, 123, anth.Usage.OutputTokens)
+}
+
+func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_cached_clamp",
+ Model: "gpt-5.2",
+ Status: "completed",
+ Usage: &ResponsesUsage{
+ InputTokens: 100,
+ OutputTokens: 5,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 150,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
+ assert.Equal(t, 0, anth.Usage.InputTokens)
+ assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
+ assert.Equal(t, 5, anth.Usage.OutputTokens)
+}
+
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
@@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.Equal(t, "tool_use", anth.Content[1].Type)
assert.Equal(t, "call_1", anth.Content[1].ID)
assert.Equal(t, "get_weather", anth.Content[1].Name)
+ assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
+}
+
+func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_read",
+ Model: "gpt-5.5",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "function_call",
+ CallID: "call_read",
+ Name: "Read",
+ Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
+ require.Len(t, anth.Content, 1)
+ assert.Equal(t, "tool_use", anth.Content[0].Type)
+ assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
+}
+
+func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_other",
+ Model: "gpt-5.5",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "function_call",
+ CallID: "call_other",
+ Name: "Search",
+ Arguments: `{"query":""}`,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
+ require.Len(t, anth.Content, 1)
+ assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
@@ -343,6 +434,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
+func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
+ state := NewResponsesEventToAnthropicState()
+ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.created",
+ Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
+ }, state)
+
+ events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.completed",
+ Response: &ResponsesResponse{
+ Status: "completed",
+ Usage: &ResponsesUsage{
+ InputTokens: 54006,
+ OutputTokens: 123,
+ TotalTokens: 54129,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 50688,
+ },
+ },
+ },
+ }, state)
+
+ require.Len(t, events, 2)
+ assert.Equal(t, "message_delta", events[0].Type)
+ assert.Equal(t, 3318, events[0].Usage.InputTokens)
+ assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
+ assert.Equal(t, 123, events[0].Usage.OutputTokens)
+ assert.Equal(t, "message_stop", events[1].Type)
+}
+
func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState()
@@ -393,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
}
+func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
+ state := NewResponsesEventToAnthropicState()
+
+ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.created",
+ Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
+ }, state)
+
+ events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.output_item.added",
+ OutputIndex: 0,
+ Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
+ }, state)
+ require.Len(t, events, 1)
+ assert.Equal(t, "content_block_start", events[0].Type)
+
+ events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.delta",
+ OutputIndex: 0,
+ Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+ }, state)
+ assert.Len(t, events, 0)
+
+ events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.done",
+ OutputIndex: 0,
+ Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+ }, state)
+ require.Len(t, events, 2)
+ assert.Equal(t, "content_block_delta", events[0].Type)
+ assert.Equal(t, "input_json_delta", events[0].Delta.Type)
+ assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
+ assert.Equal(t, "content_block_stop", events[1].Type)
+}
+
func TestStreamingReasoning(t *testing.T) {
state := NewResponsesEventToAnthropicState()
@@ -835,9 +991,40 @@ func TestAnthropicToResponses_ToolChoiceSpecific(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
- fn, ok := tc["function"].(map[string]any)
- require.True(t, ok)
- assert.Equal(t, "get_weather", fn["name"])
+ assert.Equal(t, "get_weather", tc["name"])
+ assert.NotContains(t, tc, "function")
+}
+
+func TestResponsesToAnthropicRequest_ToolChoiceFunctionName(t *testing.T) {
+ req := &ResponsesRequest{
+ Model: "gpt-5.2",
+ Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
+ ToolChoice: json.RawMessage(`{"type":"function","name":"get_weather"}`),
+ }
+
+ resp, err := ResponsesToAnthropicRequest(req)
+ require.NoError(t, err)
+
+ var tc map[string]string
+ require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
+ assert.Equal(t, "tool", tc["type"])
+ assert.Equal(t, "get_weather", tc["name"])
+}
+
+func TestResponsesToAnthropicRequest_ToolChoiceLegacyFunctionName(t *testing.T) {
+ req := &ResponsesRequest{
+ Model: "gpt-5.2",
+ Input: json.RawMessage(`[{"role":"user","content":"Hello"}]`),
+ ToolChoice: json.RawMessage(`{"type":"function","function":{"name":"get_weather"}}`),
+ }
+
+ resp, err := ResponsesToAnthropicRequest(req)
+ require.NoError(t, err)
+
+ var tc map[string]string
+ require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
+ assert.Equal(t, "tool", tc["type"])
+ assert.Equal(t, "get_weather", tc["name"])
}
// ---------------------------------------------------------------------------
diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go
index 485262e8..268f9f22 100644
--- a/backend/internal/pkg/apicompat/anthropic_to_responses.go
+++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go
@@ -75,7 +75,7 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
// {"type":"auto"} → "auto"
// {"type":"any"} → "required"
// {"type":"none"} → "none"
-// {"type":"tool","name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"type":"tool","name":"X"} → {"type":"function","name":"X"}
func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage, error) {
var tc struct {
Type string `json:"type"`
@@ -94,8 +94,8 @@ func convertAnthropicToolChoiceToResponses(raw json.RawMessage) (json.RawMessage
return json.Marshal("none")
case "tool":
return json.Marshal(map[string]any{
- "type": "function",
- "function": map[string]string{"name": tc.Name},
+ "type": "function",
+ "name": tc.Name,
})
default:
// Pass through unknown types as-is
diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
index c140449a..35d42999 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
@@ -281,6 +281,8 @@ func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
+ assert.Equal(t, "get_weather", tc["name"])
+ assert.NotContains(t, tc, "function")
}
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
index c2725406..64ef5781 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
@@ -420,7 +420,7 @@ func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []R
//
// "auto" → "auto"
// "none" → "none"
-// {"name":"X"} → {"type":"function","function":{"name":"X"}}
+// {"name":"X"} → {"type":"function","name":"X"}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
@@ -436,7 +436,7 @@ func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage,
return nil, err
}
return json.Marshal(map[string]any{
- "type": "function",
- "function": map[string]string{"name": obj.Name},
+ "type": "function",
+ "name": obj.Name,
})
}
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go
index 5409a0f4..489ed238 100644
--- a/backend/internal/pkg/apicompat/responses_to_anthropic.go
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go
@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
Type: "tool_use",
ID: fromResponsesCallID(item.CallID),
Name: item.Name,
- Input: json.RawMessage(item.Arguments),
+ Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
})
case "web_search_call":
toolUseID := "srvtoolu_" + item.ID
@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
- out.Usage = AnthropicUsage{
- InputTokens: resp.Usage.InputTokens,
- OutputTokens: resp.Usage.OutputTokens,
- }
- if resp.Usage.InputTokensDetails != nil {
- out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
- }
+ out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
}
return out
}
+func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
+ if usage == nil {
+ return AnthropicUsage{}
+ }
+
+ cachedTokens := 0
+ if usage.InputTokensDetails != nil {
+ cachedTokens = usage.InputTokensDetails.CachedTokens
+ }
+
+ inputTokens := usage.InputTokens - cachedTokens
+ if inputTokens < 0 {
+ inputTokens = 0
+ }
+
+ return AnthropicUsage{
+ InputTokens: inputTokens,
+ OutputTokens: usage.OutputTokens,
+ CacheReadInputTokens: cachedTokens,
+ }
+}
+
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
switch status {
case "incomplete":
@@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
}
}
+func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
+ if name != "Read" || raw == "" {
+ return json.RawMessage(raw)
+ }
+
+ var input map[string]json.RawMessage
+ if err := json.Unmarshal([]byte(raw), &input); err != nil {
+ return json.RawMessage(raw)
+ }
+
+ if pages, ok := input["pages"]; !ok || string(pages) != `""` {
+ return json.RawMessage(raw)
+ }
+
+ delete(input, "pages")
+ sanitized, err := json.Marshal(input)
+ if err != nil {
+ return json.RawMessage(raw)
+ }
+ return sanitized
+}
+
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
@@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
ContentBlockIndex int
ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use"
+ CurrentToolName string
+ CurrentToolArgs string
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int
@@ -165,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
case "response.function_call_arguments.delta":
return resToAnthHandleFuncArgsDelta(evt, state)
case "response.function_call_arguments.done":
- return resToAnthHandleBlockDone(state)
+ return resToAnthHandleFuncArgsDone(evt, state)
case "response.output_item.done":
return resToAnthHandleOutputItemDone(evt, state)
case "response.reasoning_summary_text.delta":
@@ -262,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "tool_use"
+ state.CurrentToolName = evt.Item.Name
+ state.CurrentToolArgs = ""
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
@@ -342,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
return nil
}
+ if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
+ state.CurrentToolArgs += evt.Delta
+ return nil
+ }
+
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
@@ -357,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
}}
}
+func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
+ if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
+ return resToAnthHandleBlockDone(state)
+ }
+
+ raw := evt.Arguments
+ if raw == "" {
+ raw = state.CurrentToolArgs
+ }
+ sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
+ if len(sanitized) == 0 {
+ return closeCurrentBlock(state)
+ }
+
+ idx := state.ContentBlockIndex
+ events := []AnthropicStreamEvent{{
+ Type: "content_block_delta",
+ Index: &idx,
+ Delta: &AnthropicDelta{
+ Type: "input_json_delta",
+ PartialJSON: string(sanitized),
+ },
+ }}
+ events = append(events, closeCurrentBlock(state)...)
+ return events
+}
+
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
@@ -466,11 +540,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
- state.InputTokens = evt.Response.Usage.InputTokens
- state.OutputTokens = evt.Response.Usage.OutputTokens
- if evt.Response.Usage.InputTokensDetails != nil {
- state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
- }
+ usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
+ state.InputTokens = usage.InputTokens
+ state.OutputTokens = usage.OutputTokens
+ state.CacheReadInputTokens = usage.CacheReadInputTokens
}
switch evt.Response.Status {
case "incomplete":
@@ -509,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
idx := state.ContentBlockIndex
state.ContentBlockOpen = false
state.ContentBlockIndex++
+ state.CurrentToolName = ""
+ state.CurrentToolArgs = ""
return []AnthropicStreamEvent{{
Type: "content_block_stop",
Index: &idx,
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
index f0a5b07e..8fa652f2 100644
--- a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
@@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
- case "web_search":
+ case "web_search", "google_search", "web_search_20250305":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",
@@ -428,7 +428,8 @@ func normalizeAnthropicInputSchema(schema json.RawMessage) json.RawMessage {
// "auto" → {"type":"auto"}
// "required" → {"type":"any"}
// "none" → {"type":"none"}
-// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"}
+// {"type":"function","name":"X"} → {"type":"tool","name":"X"}
+// {"type":"function","function":{"name":"X"}} → {"type":"tool","name":"X"} // legacy
func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try as string first
var s string
@@ -448,14 +449,22 @@ func convertResponsesToAnthropicToolChoice(raw json.RawMessage) (json.RawMessage
// Try as object with type=function
var tc struct {
Type string `json:"type"`
+ Name string `json:"name"`
Function struct {
Name string `json:"name"`
} `json:"function"`
}
- if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" && tc.Function.Name != "" {
+ if err := json.Unmarshal(raw, &tc); err == nil && tc.Type == "function" {
+ name := strings.TrimSpace(tc.Name)
+ if name == "" {
+ name = strings.TrimSpace(tc.Function.Name)
+ }
+ if name == "" {
+ return raw, nil
+ }
return json.Marshal(map[string]string{
"type": "tool",
- "name": tc.Function.Name,
+ "name": name,
})
}
diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go
index e0d1a53e..f8c6b75f 100644
--- a/backend/internal/pkg/apicompat/types.go
+++ b/backend/internal/pkg/apicompat/types.go
@@ -12,17 +12,23 @@ import "encoding/json"
// AnthropicRequest is the request body for POST /v1/messages.
type AnthropicRequest struct {
- Model string `json:"model"`
- MaxTokens int `json:"max_tokens"`
- System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
- Messages []AnthropicMessage `json:"messages"`
- Tools []AnthropicTool `json:"tools,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP *float64 `json:"top_p,omitempty"`
- StopSeqs []string `json:"stop_sequences,omitempty"`
- Thinking *AnthropicThinking `json:"thinking,omitempty"`
- ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+ Model string `json:"model"`
+ MaxTokens int `json:"max_tokens"`
+ System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
+ Messages []AnthropicMessage `json:"messages"`
+ Tools []AnthropicTool `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
+ StopSeqs []string `json:"stop_sequences,omitempty"`
+ Thinking *AnthropicThinking `json:"thinking,omitempty"`
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+ // Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id
+ // 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化
+ // 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/
+ // RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的
+ // user_id,进而导致请求被归类为第三方 app。
+ Metadata json.RawMessage `json:"metadata,omitempty"`
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
}
@@ -76,10 +82,18 @@ type AnthropicImageSource struct {
// AnthropicTool describes a tool available to the model.
type AnthropicTool struct {
- Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
+ Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
+ CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
+}
+
+// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。
+// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。
+type AnthropicCacheControl struct {
+ Type string `json:"type"` // "ephemeral"
+ TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m(由 Anthropic 判定)
}
// AnthropicResponse is the non-streaming response from POST /v1/messages.
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index 21c723d2..aa59ba64 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -4,6 +4,12 @@ package claude
// Claude Code 客户端相关常量
// Beta header 常量
+//
+// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04)。
+// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致,
+// 原因:Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源;
+// 缺少任何"官方 Claude Code 请求才会带"的 beta,都会被降级到第三方额度,
+// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.`
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
@@ -12,6 +18,13 @@ const (
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
BetaFastMode = "fast-mode-2026-02-01"
+
+ // 新增(对齐官方 CLI 2.1.9x 以来的流量)
+ BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
+ BetaEffort = "effort-2025-11-24"
+ BetaRedactThinking = "redact-thinking-2026-02-12"
+ BetaContextManagement = "context-management-2025-06-27"
+ BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11"
)
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
@@ -44,11 +57,43 @@ const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," +
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
+// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。
+// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先;
+// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。
+const DefaultCacheControlTTL = "5m"
+
+// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver)。
+// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。
+// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。
+const CLICurrentVersion = "2.1.92"
+
+// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表,
+// 用于 OAuth 账号伪装成 Claude Code 时使用。
+// 顺序与真实 CLI 抓包一致。
+//
+// 使用建议:
+// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
+// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
+// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
+func FullClaudeCodeMimicryBetas() []string {
+ return []string{
+ BetaClaudeCode,
+ BetaOAuth,
+ BetaInterleavedThinking,
+ BetaPromptCachingScope,
+ BetaEffort,
+ BetaRedactThinking,
+ BetaContextManagement,
+ BetaExtendedCacheTTL,
+ }
+}
+
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
- "User-Agent": "claude-cli/2.1.22 (external, cli)",
+ // 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。
+ "User-Agent": "claude-cli/2.1.92 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go
index 69e99dc5..cee12948 100644
--- a/backend/internal/pkg/httputil/body.go
+++ b/backend/internal/pkg/httputil/body.go
@@ -2,16 +2,28 @@ package httputil
import (
"bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "errors"
+ "fmt"
"io"
"net/http"
+ "strings"
+
+ "github.com/klauspost/compress/zstd"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
+ // maxDecompressedBodySize limits the decompressed request body to 64 MB
+ // to prevent decompression bomb attacks.
+ maxDecompressedBodySize = 64 << 20
)
-// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
+// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
+// on content length, transparently decoding any Content-Encoding the upstream
+// client used to compress the body (zstd, gzip, deflate).
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
@@ -33,5 +45,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
- return buf.Bytes(), nil
+ raw := buf.Bytes()
+
+ enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
+ if enc == "" || enc == "identity" {
+ return raw, nil
+ }
+
+ decoded, err := decompressRequestBody(enc, raw)
+ if err != nil {
+ return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
+ }
+
+ req.Header.Del("Content-Encoding")
+ req.Header.Del("Content-Length")
+ req.ContentLength = int64(len(decoded))
+
+ return decoded, nil
+}
+
+func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
+ switch encoding {
+ case "zstd":
+ dec, err := zstd.NewReader(bytes.NewReader(raw))
+ if err != nil {
+ return nil, err
+ }
+ defer dec.Close()
+ return io.ReadAll(io.LimitReader(dec, maxDecompressedBodySize))
+ case "gzip", "x-gzip":
+ gr, err := gzip.NewReader(bytes.NewReader(raw))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = gr.Close() }()
+ return io.ReadAll(io.LimitReader(gr, maxDecompressedBodySize))
+ case "deflate":
+ zr, err := zlib.NewReader(bytes.NewReader(raw))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = zr.Close() }()
+ return io.ReadAll(io.LimitReader(zr, maxDecompressedBodySize))
+ default:
+ return nil, errors.New("unsupported Content-Encoding")
+ }
}
diff --git a/backend/internal/pkg/httputil/body_test.go b/backend/internal/pkg/httputil/body_test.go
new file mode 100644
index 00000000..ed8355d5
--- /dev/null
+++ b/backend/internal/pkg/httputil/body_test.go
@@ -0,0 +1,143 @@
+package httputil
+
+import (
+ "bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/klauspost/compress/zstd"
+)
+
+const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
+
+func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
+ t.Helper()
+ req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
+ if err != nil {
+ t.Fatalf("NewRequest: %v", err)
+ }
+ if encoding != "" {
+ req.Header.Set("Content-Encoding", encoding)
+ }
+ req.ContentLength = int64(len(body))
+ return req
+}
+
+func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
+ req := newRequestWithBody(t, []byte(samplePayload), "")
+ got, err := ReadRequestBodyWithPrealloc(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(got) != samplePayload {
+ t.Fatalf("body mismatch: got %q", got)
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
+ enc, _ := zstd.NewWriter(nil)
+ compressed := enc.EncodeAll([]byte(samplePayload), nil)
+ _ = enc.Close()
+
+ req := newRequestWithBody(t, compressed, "zstd")
+ got, err := ReadRequestBodyWithPrealloc(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(got) != samplePayload {
+ t.Fatalf("body mismatch: got %q", got)
+ }
+ if req.Header.Get("Content-Encoding") != "" {
+ t.Fatalf("Content-Encoding should be cleared after decoding")
+ }
+ if req.ContentLength != int64(len(samplePayload)) {
+ t.Fatalf("ContentLength not updated: %d", req.ContentLength)
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
+ var buf bytes.Buffer
+ gw := gzip.NewWriter(&buf)
+ if _, err := gw.Write([]byte(samplePayload)); err != nil {
+ t.Fatalf("gzip write: %v", err)
+ }
+ if err := gw.Close(); err != nil {
+ t.Fatalf("gzip close: %v", err)
+ }
+
+ req := newRequestWithBody(t, buf.Bytes(), "gzip")
+ got, err := ReadRequestBodyWithPrealloc(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(got) != samplePayload {
+ t.Fatalf("body mismatch: got %q", got)
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
+ var buf bytes.Buffer
+ zw := zlib.NewWriter(&buf)
+ if _, err := zw.Write([]byte(samplePayload)); err != nil {
+ t.Fatalf("zlib write: %v", err)
+ }
+ if err := zw.Close(); err != nil {
+ t.Fatalf("zlib close: %v", err)
+ }
+
+ req := newRequestWithBody(t, buf.Bytes(), "deflate")
+ got, err := ReadRequestBodyWithPrealloc(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(got) != samplePayload {
+ t.Fatalf("body mismatch: got %q", got)
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
+ req := newRequestWithBody(t, []byte(samplePayload), "br")
+ _, err := ReadRequestBodyWithPrealloc(req)
+ if err == nil {
+ t.Fatal("expected error for unsupported encoding, got nil")
+ }
+ if !strings.Contains(err.Error(), "br") {
+ t.Fatalf("error should mention encoding, got %v", err)
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
+ req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
+ _, err := ReadRequestBodyWithPrealloc(req)
+ if err == nil {
+ t.Fatal("expected error for corrupt zstd body, got nil")
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
+ req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
+ if err != nil {
+ t.Fatalf("NewRequest: %v", err)
+ }
+ got, err := ReadRequestBodyWithPrealloc(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != nil {
+ t.Fatalf("expected nil body, got %q", got)
+ }
+}
+
+func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
+ req := newRequestWithBody(t, []byte(samplePayload), "identity")
+ got, err := ReadRequestBodyWithPrealloc(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(got) != samplePayload {
+ t.Fatalf("body mismatch: got %q", got)
+ }
+}
diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go
index 49e38bf8..be9f3aae 100644
--- a/backend/internal/pkg/openai/constants.go
+++ b/backend/internal/pkg/openai/constants.go
@@ -15,18 +15,15 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
+ {ID: "gpt-5.5", Object: "model", Created: 1776873600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.5"},
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
{ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"},
- {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
- {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
- {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
- {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
- {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
- {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
- {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
+ {ID: "gpt-image-1", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1"},
+ {ID: "gpt-image-1.5", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1.5"},
+ {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"},
}
// DefaultModelIDs returns the default model ID list
@@ -39,7 +36,7 @@ func DefaultModelIDs() []string {
}
// DefaultTestModel default model for testing OpenAI accounts
-const DefaultTestModel = "gpt-5.1-codex"
+const DefaultTestModel = "gpt-5.4"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 24115c33..78f739ac 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
+ if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE account_id = $1", id); err != nil {
+ return err
+ }
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
diff --git a/backend/internal/repository/account_repo_compact_extra_test.go b/backend/internal/repository/account_repo_compact_extra_test.go
new file mode 100644
index 00000000..604f392e
--- /dev/null
+++ b/backend/internal/repository/account_repo_compact_extra_test.go
@@ -0,0 +1,14 @@
+package repository
+
+import "testing"
+
+func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) {
+ updates := map[string]any{
+ "openai_compact_supported": true,
+ "openai_compact_checked_at": "2026-04-10T10:00:00Z",
+ }
+
+ if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
+ t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
+ }
+}
diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go
index b249bb61..d1cea9eb 100644
--- a/backend/internal/repository/account_repo_integration_test.go
+++ b/backend/internal/repository/account_repo_integration_test.go
@@ -64,6 +64,10 @@ func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket servi
return true, nil
}
+func (s *schedulerCacheRecorder) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+ return nil
+}
+
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
return nil, nil
}
diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go
new file mode 100644
index 00000000..ef89e5b6
--- /dev/null
+++ b/backend/internal/repository/affiliate_repo.go
@@ -0,0 +1,762 @@
+package repository
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+const (
+ affiliateCodeLength = 12
+ affiliateCodeMaxAttempts = 12
+)
+
+var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
+
+type affiliateQueryExecer interface {
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+}
+
+type affiliateRepository struct {
+ client *dbent.Client
+}
+
+func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
+ return &affiliateRepository{client: client}
+}
+
+func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
+ if userID <= 0 {
+ return nil, service.ErrUserNotFound
+ }
+ client := clientFromContext(ctx, r.client)
+ return ensureUserAffiliateWithClient(ctx, client, userID)
+}
+
+func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
+ client := clientFromContext(ctx, r.client)
+ return queryAffiliateByCode(ctx, client, code)
+}
+
+func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
+ var bound bool
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
+ return err
+ }
+
+ res, err := txClient.ExecContext(txCtx,
+ "UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
+ inviterID, userID,
+ )
+ if err != nil {
+ return fmt.Errorf("bind inviter: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ bound = false
+ return nil
+ }
+
+ if _, err = txClient.ExecContext(txCtx,
+ "UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
+ inviterID,
+ ); err != nil {
+ return fmt.Errorf("increment inviter aff_count: %w", err)
+ }
+ bound = true
+ return nil
+ })
+ if err != nil {
+ return false, err
+ }
+ return bound, nil
+}
+
+func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
+ if amount <= 0 {
+ return false, nil
+ }
+
+ var applied bool
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ // freezeHours > 0: add to frozen quota; == 0: add to available quota directly
+ var updateSQL string
+ if freezeHours > 0 {
+ updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
+ } else {
+ updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
+ }
+ res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
+ if err != nil {
+ return err
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ applied = false
+ return nil
+ }
+
+ if freezeHours > 0 {
+ if _, err = txClient.ExecContext(txCtx, `
+INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
+VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
+ inviterID, amount, inviteeUserID, freezeHours); err != nil {
+ return fmt.Errorf("insert affiliate accrue ledger: %w", err)
+ }
+ } else {
+ if _, err = txClient.ExecContext(txCtx, `
+INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
+VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
+ return fmt.Errorf("insert affiliate accrue ledger: %w", err)
+ }
+ }
+
+ applied = true
+ return nil
+ })
+ if err != nil {
+ return false, err
+ }
+ return applied, nil
+}
+
+func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
+ client := clientFromContext(ctx, r.client)
+ rows, err := client.QueryContext(ctx,
+ `SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
+ inviterID, inviteeUserID)
+ if err != nil {
+ return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+ var total float64
+ if rows.Next() {
+ if err := rows.Scan(&total); err != nil {
+ return 0, err
+ }
+ }
+ return total, rows.Close()
+}
+
+func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
+ var thawed float64
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ var err error
+ thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
+ return err
+ })
+ return thawed, err
+}
+
+// thawFrozenQuotaTx moves matured frozen quota to available quota within an existing tx.
+func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
+ rows, err := txClient.QueryContext(txCtx, `
+WITH matured AS (
+ UPDATE user_affiliate_ledger
+ SET frozen_until = NULL, updated_at = NOW()
+ WHERE user_id = $1
+ AND frozen_until IS NOT NULL
+ AND frozen_until <= NOW()
+ RETURNING amount
+)
+SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
+ if err != nil {
+ return 0, fmt.Errorf("thaw frozen quota: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var thawed float64
+ if rows.Next() {
+ if err := rows.Scan(&thawed); err != nil {
+ return 0, err
+ }
+ }
+ if err := rows.Close(); err != nil {
+ return 0, err
+ }
+ if thawed <= 0 {
+ return 0, nil
+ }
+
+ _, err = txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_quota = aff_quota + $1,
+ aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
+ updated_at = NOW()
+WHERE user_id = $2`, thawed, userID)
+ if err != nil {
+ return 0, fmt.Errorf("move thawed quota: %w", err)
+ }
+ return thawed, nil
+}
+
+func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
+ var transferred float64
+ var newBalance float64
+
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+
+ // Thaw any matured frozen quota before transfer.
+ if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
+ return fmt.Errorf("thaw before transfer: %w", err)
+ }
+
+ rows, err := txClient.QueryContext(txCtx, `
+WITH claimed AS (
+ SELECT aff_quota::double precision AS amount
+ FROM user_affiliates
+ WHERE user_id = $1
+ AND aff_quota > 0
+ FOR UPDATE
+),
+cleared AS (
+ UPDATE user_affiliates ua
+ SET aff_quota = 0,
+ updated_at = NOW()
+ FROM claimed c
+ WHERE ua.user_id = $1
+ RETURNING c.amount
+)
+SELECT amount
+FROM cleared`, userID)
+ if err != nil {
+ return fmt.Errorf("claim affiliate quota: %w", err)
+ }
+
+ if !rows.Next() {
+ _ = rows.Close()
+ if err := rows.Err(); err != nil {
+ return err
+ }
+ return service.ErrAffiliateQuotaEmpty
+ }
+ if err := rows.Scan(&transferred); err != nil {
+ _ = rows.Close()
+ return err
+ }
+ if err := rows.Close(); err != nil {
+ return err
+ }
+ if transferred <= 0 {
+ return service.ErrAffiliateQuotaEmpty
+ }
+
+ affected, err := txClient.User.Update().
+ Where(user.IDEQ(userID)).
+ AddBalance(transferred).
+ AddTotalRecharged(transferred).
+ Save(txCtx)
+ if err != nil {
+ return fmt.Errorf("credit user balance by affiliate quota: %w", err)
+ }
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+
+ newBalance, err = queryUserBalance(txCtx, txClient, userID)
+ if err != nil {
+ return err
+ }
+
+ if _, err = txClient.ExecContext(txCtx, `
+INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
+VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
+ return fmt.Errorf("insert affiliate transfer ledger: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return 0, 0, err
+ }
+
+ return transferred, newBalance, nil
+}
+
+func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
+ if limit <= 0 {
+ limit = 100
+ }
+ client := clientFromContext(ctx, r.client)
+ rows, err := client.QueryContext(ctx, `
+SELECT ua.user_id,
+ COALESCE(u.email, ''),
+ COALESCE(u.username, ''),
+ ua.created_at,
+ COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
+FROM user_affiliates ua
+LEFT JOIN users u ON u.id = ua.user_id
+LEFT JOIN user_affiliate_ledger ual
+ ON ual.user_id = $1
+ AND ual.source_user_id = ua.user_id
+ AND ual.action = 'accrue'
+WHERE ua.inviter_id = $1
+GROUP BY ua.user_id, u.email, u.username, ua.created_at
+ORDER BY ua.created_at DESC
+LIMIT $2`, inviterID, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ invitees := make([]service.AffiliateInvitee, 0)
+ for rows.Next() {
+ var item service.AffiliateInvitee
+ var createdAt time.Time
+ if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
+ return nil, err
+ }
+ item.CreatedAt = &createdAt
+ invitees = append(invitees, item)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return invitees, nil
+}
+
+func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return fn(ctx, tx.Client())
+ }
+
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return fmt.Errorf("begin affiliate transaction: %w", err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx, tx.Client()); err != nil {
+ return err
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("commit affiliate transaction: %w", err)
+ }
+ return nil
+}
+
+func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
+ summary, err := queryAffiliateByUserID(ctx, client, userID)
+ if err == nil {
+ return summary, nil
+ }
+ if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
+ return nil, err
+ }
+
+ for i := 0; i < affiliateCodeMaxAttempts; i++ {
+ code, codeErr := generateAffiliateCode()
+ if codeErr != nil {
+ return nil, codeErr
+ }
+ _, insertErr := client.ExecContext(ctx, `
+INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
+VALUES ($1, $2, NOW(), NOW())
+ON CONFLICT (user_id) DO NOTHING`, userID, code)
+ if insertErr == nil {
+ break
+ }
+ if isAffiliateUniqueViolation(insertErr) {
+ continue
+ }
+ return nil, insertErr
+ }
+
+ return queryAffiliateByUserID(ctx, client, userID)
+}
+
+func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
+ rows, err := client.QueryContext(ctx, `
+SELECT user_id,
+ aff_code,
+ aff_code_custom,
+ aff_rebate_rate_percent,
+ inviter_id,
+ aff_count,
+ aff_quota::double precision,
+ aff_frozen_quota::double precision,
+ aff_history_quota::double precision,
+ created_at,
+ updated_at
+FROM user_affiliates
+WHERE user_id = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, service.ErrAffiliateProfileNotFound
+ }
+
+ var out service.AffiliateSummary
+ var inviterID sql.NullInt64
+ var rebateRate sql.NullFloat64
+ if err := rows.Scan(
+ &out.UserID,
+ &out.AffCode,
+ &out.AffCodeCustom,
+ &rebateRate,
+ &inviterID,
+ &out.AffCount,
+ &out.AffQuota,
+ &out.AffFrozenQuota,
+ &out.AffHistoryQuota,
+ &out.CreatedAt,
+ &out.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ if inviterID.Valid {
+ out.InviterID = &inviterID.Int64
+ }
+ if rebateRate.Valid {
+ v := rebateRate.Float64
+ out.AffRebateRatePercent = &v
+ }
+ return &out, nil
+}
+
+func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
+ rows, err := client.QueryContext(ctx, `
+SELECT user_id,
+ aff_code,
+ aff_code_custom,
+ aff_rebate_rate_percent,
+ inviter_id,
+ aff_count,
+ aff_quota::double precision,
+ aff_frozen_quota::double precision,
+ aff_history_quota::double precision,
+ created_at,
+ updated_at
+FROM user_affiliates
+WHERE aff_code = $1
+LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, service.ErrAffiliateProfileNotFound
+ }
+
+ var out service.AffiliateSummary
+ var inviterID sql.NullInt64
+ var rebateRate sql.NullFloat64
+ if err := rows.Scan(
+ &out.UserID,
+ &out.AffCode,
+ &out.AffCodeCustom,
+ &rebateRate,
+ &inviterID,
+ &out.AffCount,
+ &out.AffQuota,
+ &out.AffFrozenQuota,
+ &out.AffHistoryQuota,
+ &out.CreatedAt,
+ &out.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ if inviterID.Valid {
+ out.InviterID = &inviterID.Int64
+ }
+ if rebateRate.Valid {
+ v := rebateRate.Float64
+ out.AffRebateRatePercent = &v
+ }
+ return &out, nil
+}
+
+func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
+ rows, err := client.QueryContext(ctx,
+ "SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
+ userID,
+ )
+ if err != nil {
+ return 0, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return 0, service.ErrUserNotFound
+ }
+ var balance float64
+ if err := rows.Scan(&balance); err != nil {
+ return 0, err
+ }
+ return balance, nil
+}
+
+func generateAffiliateCode() (string, error) {
+ buf := make([]byte, affiliateCodeLength)
+ if _, err := rand.Read(buf); err != nil {
+ return "", fmt.Errorf("generate affiliate code: %w", err)
+ }
+ for i := range buf {
+ buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
+ }
+ return string(buf), nil
+}
+
+func isAffiliateUniqueViolation(err error) bool {
+ var pqErr *pq.Error
+ if errors.As(err, &pqErr) {
+ return string(pqErr.Code) == "23505"
+ }
+ return false
+}
+
+// UpdateUserAffCode 改写用户的邀请码(自定义专属邀请码)。
+// 唯一性冲突返回 ErrAffiliateCodeTaken。
+func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
+ if userID <= 0 {
+ return service.ErrUserNotFound
+ }
+ code := strings.ToUpper(strings.TrimSpace(newCode))
+ if code == "" {
+ return service.ErrAffiliateCodeInvalid
+ }
+
+ return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ res, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_code = $1,
+ aff_code_custom = true,
+ updated_at = NOW()
+WHERE user_id = $2`, code, userID)
+ if err != nil {
+ if isAffiliateUniqueViolation(err) {
+ return service.ErrAffiliateCodeTaken
+ }
+ return fmt.Errorf("update aff_code: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+ return nil
+ })
+}
+
+// ResetUserAffCode 把 aff_code 还原为系统随机码,并清除 aff_code_custom 标记。
+func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
+ if userID <= 0 {
+ return "", service.ErrUserNotFound
+ }
+ var newCode string
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ for i := 0; i < affiliateCodeMaxAttempts; i++ {
+ candidate, codeErr := generateAffiliateCode()
+ if codeErr != nil {
+ return codeErr
+ }
+ res, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_code = $1,
+ aff_code_custom = false,
+ updated_at = NOW()
+WHERE user_id = $2`, candidate, userID)
+ if err != nil {
+ if isAffiliateUniqueViolation(err) {
+ continue
+ }
+ return fmt.Errorf("reset aff_code: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+ newCode = candidate
+ return nil
+ }
+ return fmt.Errorf("reset aff_code: exhausted attempts")
+ })
+ if err != nil {
+ return "", err
+ }
+ return newCode, nil
+}
+
+// SetUserRebateRate 设置或清除用户专属返利比例。ratePercent==nil 表示清除(沿用全局)。
+func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
+ if userID <= 0 {
+ return service.ErrUserNotFound
+ }
+ return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
+ return err
+ }
+ // nullableArg lets us use a single UPDATE for both "set value" and
+ // "clear" cases — database/sql converts nil interface{} to SQL NULL.
+ res, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_rebate_rate_percent = $1,
+ updated_at = NOW()
+WHERE user_id = $2`, nullableArg(ratePercent), userID)
+ if err != nil {
+ return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
+ }
+ affected, _ := res.RowsAffected()
+ if affected == 0 {
+ return service.ErrUserNotFound
+ }
+ return nil
+ })
+}
+
+// BatchSetUserRebateRate 批量为多个用户设置专属比例(nil 清除)。
+func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
+ if len(userIDs) == 0 {
+ return nil
+ }
+ return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ for _, uid := range userIDs {
+ if uid <= 0 {
+ continue
+ }
+ if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
+ return err
+ }
+ }
+ _, err := txClient.ExecContext(txCtx, `
+UPDATE user_affiliates
+SET aff_rebate_rate_percent = $1,
+ updated_at = NOW()
+WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs))
+ if err != nil {
+ return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
+ }
+ return nil
+ })
+}
+
+// nullableArg unwraps a *float64 into an interface{} suitable for SQL parameter
+// binding: nil pointer → SQL NULL, non-nil → the float value.
+func nullableArg(v *float64) any {
+ if v == nil {
+ return nil
+ }
+ return *v
+}
+
+// ListUsersWithCustomSettings 列出有专属配置(自定义码或专属比例)的用户。
+//
+// 单一查询同时处理"无搜索"与"按邮箱/用户名模糊搜索":
+// 空 search 时拼接出的 LIKE 模式为 "%%",匹配所有行;非空时按 ILIKE 子串匹配。
+// 这避免了为两种情况维护两份 SQL 模板。
+func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
+ page := filter.Page
+ if page < 1 {
+ page = 1
+ }
+ pageSize := filter.PageSize
+ if pageSize <= 0 || pageSize > 200 {
+ pageSize = 20
+ }
+ offset := (page - 1) * pageSize
+ likePattern := "%" + strings.TrimSpace(filter.Search) + "%"
+
+ const baseFrom = `
+FROM user_affiliates ua
+JOIN users u ON u.id = ua.user_id
+WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
+ AND (u.email ILIKE $1 OR u.username ILIKE $1)`
+
+ client := clientFromContext(ctx, r.client)
+
+ total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
+ if err != nil {
+ return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
+ }
+
+ listQuery := `
+SELECT ua.user_id,
+ COALESCE(u.email, ''),
+ COALESCE(u.username, ''),
+ ua.aff_code,
+ ua.aff_code_custom,
+ ua.aff_rebate_rate_percent,
+ ua.aff_count` + baseFrom + `
+ORDER BY ua.updated_at DESC
+LIMIT $2 OFFSET $3`
+
+ rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
+ if err != nil {
+ return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ entries := make([]service.AffiliateAdminEntry, 0)
+ for rows.Next() {
+ var e service.AffiliateAdminEntry
+ var rebate sql.NullFloat64
+ if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
+ &e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
+ return nil, 0, err
+ }
+ if rebate.Valid {
+ v := rebate.Float64
+ e.AffRebateRatePercent = &v
+ }
+ entries = append(entries, e)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, 0, err
+ }
+ return entries, total, nil
+}
+
+// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT).
+func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
+ rows, err := client.QueryContext(ctx, query, args...)
+ if err != nil {
+ return 0, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return 0, nil
+ }
+ var v int64
+ if err := rows.Scan(&v); err != nil {
+ return 0, err
+ }
+ return v, nil
+}
diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go
new file mode 100644
index 00000000..697a193b
--- /dev/null
+++ b/backend/internal/repository/affiliate_repo_integration_test.go
@@ -0,0 +1,399 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
+ t.Helper()
+ rows, err := client.QueryContext(ctx, query, args...)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next(), "expected one row")
+ var value float64
+ require.NoError(t, rows.Scan(&value))
+ require.NoError(t, rows.Err())
+ return value
+}
+
+func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
+ t.Helper()
+ rows, err := client.QueryContext(ctx, query, args...)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next(), "expected one row")
+ var value int
+ require.NoError(t, rows.Scan(&value))
+ require.NoError(t, rows.Err())
+ return value
+}
+
+func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 5.5,
+ Concurrency: 5,
+ })
+
+ affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
+ _, err := client.ExecContext(txCtx, `
+INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
+VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
+ require.NoError(t, err)
+
+ transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
+ require.NoError(t, err)
+ require.InDelta(t, 12.34, transferred, 1e-9)
+ require.InDelta(t, 17.84, balance, 1e-9)
+
+ affQuota := querySingleFloat(t, txCtx, client,
+ "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
+ require.InDelta(t, 0.0, affQuota, 1e-9)
+
+ persistedBalance := querySingleFloat(t, txCtx, client,
+ "SELECT balance::double precision FROM users WHERE id = $1", u.ID)
+ require.InDelta(t, 17.84, persistedBalance, 1e-9)
+
+ ledgerCount := querySingleInt(t, txCtx, client,
+ "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
+ require.Equal(t, 1, ledgerCount)
+}
+
+// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
+// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
+// that already carries a transaction (via dbent.NewTxContext), repo.withTx
+// must reuse that tx rather than opening a nested one. If this invariant
+// breaks, AccrueQuota would commit independently and survive a rollback of
+// the outer tx, which would violate payment_fulfillment's all-or-nothing
+// semantics.
+func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
+ ctx := context.Background()
+
+ outerTx, err := integrationEntClient.Tx(ctx)
+ require.NoError(t, err, "begin outer tx")
+ // Defensive cleanup: if any require.* below fires before the explicit
+ // Rollback, this prevents the tx from leaking until container teardown.
+ // Rollback is idempotent at the driver level (extra rollback returns an
+ // error we ignore).
+ t.Cleanup(func() { _ = outerTx.Rollback() })
+ client := outerTx.Client()
+ txCtx := dbent.NewTxContext(ctx, outerTx)
+
+ inviter := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 5,
+ })
+ invitee := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 5,
+ })
+
+ repo := NewAffiliateRepository(client, integrationDB)
+ _, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
+ require.NoError(t, err)
+ _, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
+ require.NoError(t, err)
+
+ bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
+ require.NoError(t, err)
+ require.True(t, bound, "invitee must bind to inviter")
+
+ applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
+ require.NoError(t, err)
+ require.True(t, applied, "AccrueQuota must report applied=true")
+
+ // Visible inside the outer tx.
+ innerQuota := querySingleFloat(t, txCtx, client,
+ "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
+ require.InDelta(t, 3.5, innerQuota, 1e-9)
+
+ // Roll back the outer tx; if AccrueQuota had opened its own inner tx and
+ // committed it, the rows would still be visible to the global client.
+ require.NoError(t, outerTx.Rollback())
+
+ rows, err := integrationEntClient.QueryContext(ctx,
+ "SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
+ inviter.ID, invitee.ID)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+ require.True(t, rows.Next())
+ var postRollbackCount int
+ require.NoError(t, rows.Scan(&postRollbackCount))
+ require.Equal(t, 0, postRollbackCount,
+ "AccrueQuota must propagate the outer tx — found persisted rows after rollback")
+}
+
+func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 3.21,
+ Concurrency: 5,
+ })
+
+ affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
+ _, err := client.ExecContext(txCtx, `
+INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
+VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
+ require.NoError(t, err)
+
+ transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
+ require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
+ require.InDelta(t, 0.0, transferred, 1e-9)
+ require.InDelta(t, 0.0, balance, 1e-9)
+
+ persistedBalance := querySingleFloat(t, txCtx, client,
+ "SELECT balance::double precision FROM users WHERE id = $1", u.ID)
+ require.InDelta(t, 3.21, persistedBalance, 1e-9)
+}
+
+// TestAffiliateRepository_AdminCustomCode covers the success path of admin
+// invite-code rewrite + reset within a shared test transaction:
+// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
+// - the old code can no longer be found
+// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
+//
+// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
+// test because a unique-violation aborts the surrounding Postgres tx, which
+// would poison subsequent assertions in the same transaction.
+func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+
+ original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
+ originalCode := original.AffCode
+
+ // Rewrite to a custom code
+ customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
+
+ updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.Equal(t, customCode, updated.AffCode)
+ require.True(t, updated.AffCodeCustom)
+
+ // Lookup by new custom code finds the user
+ byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
+ require.NoError(t, err)
+ require.Equal(t, u.ID, byCode.UserID)
+
+ // Old system code should no longer match
+ _, err = repo.GetAffiliateByCode(txCtx, originalCode)
+ require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
+
+ // Reset back to a fresh system code, clears custom flag
+ newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
+ require.NoError(t, err)
+ require.NotEqual(t, customCode, newSysCode)
+
+ reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.Equal(t, newSysCode, reset.AffCode)
+ require.False(t, reset.AffCodeCustom)
+
+ // The old custom code is now free again
+ _, err = repo.GetAffiliateByCode(txCtx, customCode)
+ require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
+}
+
+// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
+// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
+// this test must be the only assertion and run in its own tx — production
+// callers each have their own outer tx, so this matches real behavior.
+func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ taker := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ requester := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+
+ takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
+
+ // Now requester tries to grab the same code → conflict.
+ err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
+ require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
+}
+
+// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
+// set/clear and the Batch variant including NULL semantics.
+func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u1 := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ u2 := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+
+ // Set exclusive rate for u1
+ rate := 42.5
+ require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
+
+ got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
+ require.NoError(t, err)
+ require.NotNil(t, got.AffRebateRatePercent)
+ require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
+
+ // Clear exclusive rate
+ require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
+ cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
+ require.NoError(t, err)
+ require.Nil(t, cleared.AffRebateRatePercent)
+
+ // Batch set both users
+ batchRate := 15.0
+ require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
+
+ for _, uid := range []int64{u1.ID, u2.ID} {
+ v, err := repo.EnsureUserAffiliate(txCtx, uid)
+ require.NoError(t, err)
+ require.NotNil(t, v.AffRebateRatePercent)
+ require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
+ }
+
+ // Batch clear
+ require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
+ for _, uid := range []int64{u1.ID, u2.ID} {
+ v, err := repo.EnsureUserAffiliate(txCtx, uid)
+ require.NoError(t, err)
+ require.Nil(t, v.AffRebateRatePercent)
+ }
+}
+
+// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
+// only includes users with at least one override applied.
+func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ // User without any custom config — should NOT appear in the list.
+ plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
+ uPlain := mustCreateUser(t, client, &service.User{
+ Email: plainEmail, PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ _, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
+ require.NoError(t, err)
+
+ // User with a custom code — should appear.
+ uCode := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
+
+ // User with only an exclusive rate — should appear.
+ uRate := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ r := 33.3
+ require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
+
+ entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
+ Page: 1, PageSize: 100,
+ })
+ require.NoError(t, err)
+
+ // Build a quick lookup to assert per-user attributes (other tests may have
+ // inserted custom rows in the same DB; we only care about our 3).
+ byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
+ for _, e := range entries {
+ byUserID[e.UserID] = e
+ }
+
+ require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
+
+ codeEntry, ok := byUserID[uCode.ID]
+ require.True(t, ok, "custom-code user missing from list")
+ require.True(t, codeEntry.AffCodeCustom)
+ require.Nil(t, codeEntry.AffRebateRatePercent)
+
+ rateEntry, ok := byUserID[uRate.ID]
+ require.True(t, ok, "custom-rate user missing from list")
+ require.False(t, rateEntry.AffCodeCustom)
+ require.NotNil(t, rateEntry.AffRebateRatePercent)
+ require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
+
+ require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
+}
diff --git a/backend/internal/repository/announcement_read_repo.go b/backend/internal/repository/announcement_read_repo.go
index 2dc346b1..5268ec45 100644
--- a/backend/internal/repository/announcement_read_repo.go
+++ b/backend/internal/repository/announcement_read_repo.go
@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
- return client.AnnouncementRead.Create().
+ err := client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
+ if isSQLNoRowsError(err) {
+ return nil
+ }
+ return err
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 38ea9bde..3a527405 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -149,6 +149,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
+ user.FieldSignupSource,
+ user.FieldLastLoginAt,
+ user.FieldLastActiveAt,
+ user.FieldRpmLimit,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -175,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
group.FieldMessagesDispatchModelConfig,
+ group.FieldRpmLimit,
)
}).
Only(ctx)
@@ -656,6 +661,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
+ SignupSource: u.SignupSource,
+ LastLoginAt: u.LastLoginAt,
+ LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
@@ -663,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User {
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
TotalRecharged: u.TotalRecharged,
+ RPMLimit: u.RpmLimit,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
@@ -707,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel,
MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
+ RPMLimit: g.RpmLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
new file mode 100644
index 00000000..7e34777a
--- /dev/null
+++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
@@ -0,0 +1,80 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql")
+ migration108SQL, err := os.ReadFile(migration108Path)
+ require.NoError(t, err)
+
+ migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
+ migration108aSQL, err := os.ReadFile(migration108aPath)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, `
+DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE;
+DROP TABLE IF EXISTS auth_identity_channels CASCADE;
+DROP TABLE IF EXISTS identity_adoption_decisions CASCADE;
+DROP TABLE IF EXISTS pending_auth_sessions CASCADE;
+DROP TABLE IF EXISTS auth_identities CASCADE;
+
+ALTER TABLE users
+ DROP COLUMN IF EXISTS signup_source,
+ DROP COLUMN IF EXISTS last_login_at,
+ DROP COLUMN IF EXISTS last_active_at;
+`)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108aSQL))
+ require.NoError(t, err)
+
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&userID))
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ var reportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(userID, 10)).Scan(&reportCount))
+ require.Equal(t, 1, reportCount)
+
+ var reportTypeLimit int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeLimit))
+ require.GreaterOrEqual(t, reportTypeLimit, 45)
+
+ require.NotZero(t, userID)
+}
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
new file mode 100644
index 00000000..e64934c5
--- /dev/null
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -0,0 +1,959 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoUserID))
+
+ var wechatUnionUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionUserID))
+
+ var wechatOpenIDOnlyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
+
+ var syntheticAuthIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
+RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
+
+ var linuxDoLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoUserID).Scan(&linuxDoLegacyID))
+
+ var wechatUnionLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
+RETURNING id
+`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
+
+ var wechatOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
+RETURNING id
+`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-user-1'
+`, linuxDoUserID).Scan(&linuxDoCount))
+ require.Equal(t, 1, linuxDoCount)
+
+ var wechatSubject string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT provider_subject
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-1'
+`, wechatUnionUserID).Scan(&wechatSubject))
+ require.Equal(t, "union-1", wechatSubject)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels channel
+JOIN auth_identities ai ON ai.id = channel.identity_id
+WHERE ai.user_id = $1
+ AND channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = 'oa'
+ AND channel.channel_app_id = 'wx-app-1'
+ AND channel.channel_subject = 'openid-union-1'
+`, wechatUnionUserID).Scan(&wechatChannelCount))
+ require.Equal(t, 1, wechatChannelCount)
+
+ var legacyOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
+ require.Equal(t, 1, legacyOpenIDOnlyReportCount)
+
+ var syntheticReviewCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
+ require.Equal(t, 1, syntheticReviewCount)
+
+ var unionLegacyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
+ require.Zero(t, unionLegacyReportCount)
+ require.NotZero(t, linuxDoLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+ `).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
+
+func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoMalformedUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoMalformedUserID))
+
+ var linuxDoArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoArrayUserID))
+
+ var wechatUnionArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionArrayUserID))
+
+ var wechatOpenIDArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDArrayUserID))
+
+ var linuxDoMalformedLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
+RETURNING id
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
+
+ var linuxDoArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
+RETURNING id
+`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
+
+ var wechatUnionArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
+RETURNING id
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
+
+ var wechatOpenIDArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
+RETURNING id
+`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var linuxDoMalformedMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-malformed'
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
+ require.Equal(t, "object", linuxDoMalformedMetadataType)
+
+ var linuxDoArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-array'
+`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
+ require.Equal(t, "object", linuxDoArrayMetadataType)
+
+ var wechatUnionArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-array'
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
+ require.Equal(t, "object", wechatUnionArrayMetadataType)
+
+ var invalidJSONReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
+ require.Equal(t, "object", invalidJSONReportDetailsType)
+
+ var openIDOnlyReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
+ require.Equal(t, "object", openIDOnlyReportDetailsType)
+
+ var preservedArrayMetadataCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE id IN (
+ SELECT id
+ FROM auth_identities
+ WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
+ OR (user_id = $2 AND provider_subject = 'union-array')
+)
+ AND metadata ? '_legacy_metadata_raw_json'
+`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
+ require.Equal(t, 2, preservedArrayMetadataCount)
+
+ require.NotZero(t, linuxDoArrayLegacyID)
+ require.NotZero(t, wechatUnionArrayLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ userIDs := make([]int64, 0, 8)
+ for _, email := range []string{
+ "linuxdo-conflict-legacy@example.com",
+ "linuxdo-conflict-owner@example.com",
+ "wechat-conflict-legacy@example.com",
+ "wechat-conflict-owner@example.com",
+ "wechat-channel-legacy@example.com",
+ "wechat-channel-owner@example.com",
+ "linuxdo-invalid-json@example.com",
+ "wechat-openid-invalid-json@example.com",
+ } {
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ($1, 'hash', 'user', 'active', 0, 1)
+RETURNING id`, email).Scan(&userID))
+ userIDs = append(userIDs, userID)
+ }
+
+ linuxdoConflictLegacyUserID := userIDs[0]
+ linuxdoConflictOwnerUserID := userIDs[1]
+ wechatConflictLegacyUserID := userIDs[2]
+ wechatConflictOwnerUserID := userIDs[3]
+ wechatChannelLegacyUserID := userIDs[4]
+ wechatChannelOwnerUserID := userIDs[5]
+ linuxdoInvalidJSONUserID := userIDs[6]
+ wechatInvalidOpenIDUserID := userIDs[7]
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb)
+RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb)
+RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64)))
+
+ var wechatChannelOwnerIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64)))
+
+ var linuxdoConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}')
+RETURNING id
+`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID))
+
+ var wechatConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}')
+RETURNING id
+`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID))
+
+ var wechatChannelConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}')
+RETURNING id
+`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID))
+
+ var linuxdoInvalidJSONLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid')
+RETURNING id
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID))
+
+ var wechatInvalidOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid')
+RETURNING id
+`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxdoConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount))
+ require.Equal(t, 1, linuxdoConflictReportCount)
+
+ var wechatConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount))
+ require.Equal(t, 1, wechatConflictReportCount)
+
+ var channelConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_channel_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount))
+ require.Equal(t, 1, channelConflictReportCount)
+
+ var invalidJSONReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key IN ($1, $2)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount))
+ require.Equal(t, 2, invalidJSONReportCount)
+
+ var linuxdoInvalidIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-invalid-json'
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount))
+ require.Equal(t, 1, linuxdoInvalidIdentityCount)
+
+ var wechatOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount))
+ require.Equal(t, 1, wechatOpenIDOnlyReportCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+ `).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoFirstUserID))
+
+ var linuxDoSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoSecondUserID))
+
+ var wechatFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatFirstUserID))
+
+ var wechatSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatSecondUserID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoFirstUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoSecondUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
+RETURNING id
+`, wechatFirstUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
+RETURNING id
+`, wechatSecondUserID).Scan(new(int64)))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-ambiguous-subject'
+`).Scan(&linuxDoIdentityCount))
+ require.Zero(t, linuxDoIdentityCount)
+
+ var wechatIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-ambiguous-subject'
+`).Scan(&wechatIdentityCount))
+ require.Zero(t, wechatIdentityCount)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels
+WHERE provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND channel = 'oa'
+ AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
+`).Scan(&wechatChannelCount))
+ require.Zero(t, wechatChannelCount)
+}
+
+func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoFirstUserID))
+
+ var linuxDoSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoSecondUserID))
+
+ var wechatFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatFirstUserID))
+
+ var wechatSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatSecondUserID))
+
+ var linuxDoFirstLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
+
+ var linuxDoSecondLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
+
+ var wechatFirstLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
+RETURNING id
+`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
+
+ var wechatSecondLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
+RETURNING id
+`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var identityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
+ OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
+`).Scan(&identityCount))
+ require.Zero(t, identityCount)
+
+ var conflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key IN ($1, $2, $3, $4)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
+ require.Equal(t, 4, conflictReportCount)
+
+ var winnerAttributedReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key IN ($1, $2, $3, $4)
+ AND details ->> 'existing_identity_id' IS NOT NULL
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
+ require.Zero(t, winnerAttributedReportCount)
+}
+
+func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
+ migration108aSQL, err := os.ReadFile(migration108aPath)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ _, err = tx.ExecContext(ctx, `
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(40);
+`)
+ require.NoError(t, err)
+
+ var oidcSyntheticUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&oidcSyntheticUserID))
+
+ var linuxdoLegacyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxdoLegacyUserID))
+
+ var invalidMetadataLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
+RETURNING id
+`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration108aSQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var reportTypeWidth int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeWidth))
+ require.Equal(t, 80, reportTypeWidth)
+
+ var oidcSyntheticRecoveryReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount))
+ require.Equal(t, 1, oidcSyntheticRecoveryReportCount)
+
+ var invalidMetadataReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount))
+ require.Equal(t, 1, invalidMetadataReportCount)
+}
+
+func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) {
+ t.Helper()
+
+ _, err := tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ provider TEXT NOT NULL,
+ provider_user_id TEXT NOT NULL,
+ provider_union_id TEXT NULL,
+ provider_username TEXT NOT NULL DEFAULT '',
+ display_name TEXT NOT NULL DEFAULT '',
+ profile_url TEXT NOT NULL DEFAULT '',
+ avatar_url TEXT NOT NULL DEFAULT '',
+ metadata TEXT NOT NULL DEFAULT '{}',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+`)
+ require.NoError(t, err)
+}
+
+func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) {
+ t.Helper()
+
+ _, err := tx.ExecContext(ctx, `
+TRUNCATE TABLE
+ auth_identity_channels,
+ identity_adoption_decisions,
+ pending_auth_sessions,
+ auth_identities,
+ auth_identity_migration_reports,
+ user_provider_default_grants,
+ user_avatars,
+ user_external_identities,
+ users
+RESTART IDENTITY CASCADE;
+`)
+ require.NoError(t, err)
+}
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
new file mode 100644
index 00000000..800ee43b
--- /dev/null
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -0,0 +1,755 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+// channelMonitorRepository 实现 service.ChannelMonitorRepository。
+//
+// 选型说明:
+// - CRUD 走 ent,复用项目的事务上下文支持
+// - 聚合查询(latest per model / availability)走原生 SQL,避免 ent 在 GROUP BY 上
+// 的样板代码,并保证索引能被命中
+type channelMonitorRepository struct {
+ client *dbent.Client
+ db *sql.DB
+}
+
+// NewChannelMonitorRepository 创建仓储实例。
+func NewChannelMonitorRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRepository {
+ return &channelMonitorRepository{client: client, db: db}
+}
+
+// ---------- CRUD ----------
+
+func (r *channelMonitorRepository) Create(ctx context.Context, m *service.ChannelMonitor) error {
+ client := clientFromContext(ctx, r.client)
+ builder := client.ChannelMonitor.Create().
+ SetName(m.Name).
+ SetProvider(channelmonitor.Provider(m.Provider)).
+ SetEndpoint(m.Endpoint).
+ SetAPIKeyEncrypted(m.APIKey). // 调用方传入的已是密文
+ SetPrimaryModel(m.PrimaryModel).
+ SetExtraModels(emptySliceIfNil(m.ExtraModels)).
+ SetGroupName(m.GroupName).
+ SetEnabled(m.Enabled).
+ SetIntervalSeconds(m.IntervalSeconds).
+ SetCreatedBy(m.CreatedBy).
+ SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
+ if m.TemplateID != nil {
+ builder = builder.SetTemplateID(*m.TemplateID)
+ }
+ if m.BodyOverride != nil {
+ builder = builder.SetBodyOverride(m.BodyOverride)
+ }
+
+ created, err := builder.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+ }
+ m.ID = created.ID
+ m.CreatedAt = created.CreatedAt
+ m.UpdatedAt = created.UpdatedAt
+ return nil
+}
+
+func (r *channelMonitorRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitor, error) {
+ row, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.IDEQ(id)).
+ Only(ctx)
+ if err != nil {
+ return nil, translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+ }
+ return entToServiceMonitor(row), nil
+}
+
+func (r *channelMonitorRepository) Update(ctx context.Context, m *service.ChannelMonitor) error {
+ client := clientFromContext(ctx, r.client)
+ updater := client.ChannelMonitor.UpdateOneID(m.ID).
+ SetName(m.Name).
+ SetProvider(channelmonitor.Provider(m.Provider)).
+ SetEndpoint(m.Endpoint).
+ SetAPIKeyEncrypted(m.APIKey).
+ SetPrimaryModel(m.PrimaryModel).
+ SetExtraModels(emptySliceIfNil(m.ExtraModels)).
+ SetGroupName(m.GroupName).
+ SetEnabled(m.Enabled).
+ SetIntervalSeconds(m.IntervalSeconds).
+ SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
+ if m.TemplateID != nil {
+ updater = updater.SetTemplateID(*m.TemplateID)
+ } else {
+ updater = updater.ClearTemplateID()
+ }
+ if m.BodyOverride != nil {
+ updater = updater.SetBodyOverride(m.BodyOverride)
+ } else {
+ updater = updater.ClearBodyOverride()
+ }
+
+ updated, err := updater.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+ }
+ m.UpdatedAt = updated.UpdatedAt
+ return nil
+}
+
+func (r *channelMonitorRepository) Delete(ctx context.Context, id int64) error {
+ client := clientFromContext(ctx, r.client)
+ if err := client.ChannelMonitor.DeleteOneID(id).Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+ }
+ return nil
+}
+
+func (r *channelMonitorRepository) List(ctx context.Context, params service.ChannelMonitorListParams) ([]*service.ChannelMonitor, int64, error) {
+ q := r.client.ChannelMonitor.Query()
+ if params.Provider != "" {
+ q = q.Where(channelmonitor.ProviderEQ(channelmonitor.Provider(params.Provider)))
+ }
+ if params.Enabled != nil {
+ q = q.Where(channelmonitor.EnabledEQ(*params.Enabled))
+ }
+ if s := strings.TrimSpace(params.Search); s != "" {
+ q = q.Where(channelmonitor.Or(
+ channelmonitor.NameContainsFold(s),
+ channelmonitor.GroupNameContainsFold(s),
+ channelmonitor.PrimaryModelContainsFold(s),
+ ))
+ }
+
+ total, err := q.Count(ctx)
+ if err != nil {
+ return nil, 0, fmt.Errorf("count monitors: %w", err)
+ }
+
+ pageSize := params.PageSize
+ if pageSize <= 0 {
+ pageSize = 20
+ }
+ page := params.Page
+ if page <= 0 {
+ page = 1
+ }
+
+ rows, err := q.
+ Order(dbent.Desc(channelmonitor.FieldID)).
+ Offset((page - 1) * pageSize).
+ Limit(pageSize).
+ All(ctx)
+ if err != nil {
+ return nil, 0, fmt.Errorf("list monitors: %w", err)
+ }
+
+ out := make([]*service.ChannelMonitor, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, entToServiceMonitor(row))
+ }
+ return out, int64(total), nil
+}
+
+// ---------- 调度器辅助 ----------
+
+func (r *channelMonitorRepository) ListEnabled(ctx context.Context) ([]*service.ChannelMonitor, error) {
+ rows, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.EnabledEQ(true)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list enabled monitors: %w", err)
+ }
+ out := make([]*service.ChannelMonitor, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, entToServiceMonitor(row))
+ }
+ return out, nil
+}
+
+func (r *channelMonitorRepository) MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error {
+ client := clientFromContext(ctx, r.client)
+ if err := client.ChannelMonitor.UpdateOneID(id).
+ SetLastCheckedAt(checkedAt).
+ Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+ }
+ return nil
+}
+
+func (r *channelMonitorRepository) InsertHistoryBatch(ctx context.Context, rows []*service.ChannelMonitorHistoryRow) error {
+ if len(rows) == 0 {
+ return nil
+ }
+ client := clientFromContext(ctx, r.client)
+ bulk := make([]*dbent.ChannelMonitorHistoryCreate, 0, len(rows))
+ for _, row := range rows {
+ c := client.ChannelMonitorHistory.Create().
+ SetMonitorID(row.MonitorID).
+ SetModel(row.Model).
+ SetStatus(channelmonitorhistory.Status(row.Status)).
+ SetMessage(row.Message).
+ SetCheckedAt(row.CheckedAt)
+ if row.LatencyMs != nil {
+ c = c.SetLatencyMs(*row.LatencyMs)
+ }
+ if row.PingLatencyMs != nil {
+ c = c.SetPingLatencyMs(*row.PingLatencyMs)
+ }
+ bulk = append(bulk, c)
+ }
+ if _, err := client.ChannelMonitorHistory.CreateBulk(bulk...).Save(ctx); err != nil {
+ return fmt.Errorf("insert history bulk: %w", err)
+ }
+ return nil
+}
+
+// DeleteHistoryBefore 物理删 checked_at < before 的明细,分批 channelMonitorPruneBatchSize 行一批,
+// 避免单事务删除过多引起锁/WAL 压力。借助 (checked_at) 索引定位小批 id,再按 id 删。
+func (r *channelMonitorRepository) DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error) {
+ return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneHistorySQL, before)
+}
+
+// ListHistory 按 checked_at 倒序返回某个监控的最近 N 条历史记录。
+// model 为空时不过滤;非空时只返回该模型的记录。
+func (r *channelMonitorRepository) ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*service.ChannelMonitorHistoryEntry, error) {
+ q := r.client.ChannelMonitorHistory.Query().
+ Where(channelmonitorhistory.MonitorIDEQ(monitorID))
+ if strings.TrimSpace(model) != "" {
+ q = q.Where(channelmonitorhistory.ModelEQ(model))
+ }
+ rows, err := q.
+ Order(dbent.Desc(channelmonitorhistory.FieldCheckedAt)).
+ Limit(limit).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list history: %w", err)
+ }
+ out := make([]*service.ChannelMonitorHistoryEntry, 0, len(rows))
+ for _, row := range rows {
+ entry := &service.ChannelMonitorHistoryEntry{
+ ID: row.ID,
+ Model: row.Model,
+ Status: string(row.Status),
+ LatencyMs: row.LatencyMs,
+ PingLatencyMs: row.PingLatencyMs,
+ Message: row.Message,
+ CheckedAt: row.CheckedAt,
+ }
+ out = append(out, entry)
+ }
+ return out, nil
+}
+
+// ---------- 用户视图聚合(原生 SQL) ----------
+
+// ListLatestPerModel 用 DISTINCT ON 取每个 (monitor_id, model) 的最近一条记录。
+// 借助 (monitor_id, model, checked_at DESC) 索引可走 Index Scan。
+func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monitorID int64) ([]*service.ChannelMonitorLatest, error) {
+ const q = `
+ SELECT DISTINCT ON (model)
+ model, status, latency_ms, ping_latency_ms, checked_at
+ FROM channel_monitor_histories
+ WHERE monitor_id = $1
+ ORDER BY model, checked_at DESC
+ `
+ rows, err := r.db.QueryContext(ctx, q, monitorID)
+ if err != nil {
+ return nil, fmt.Errorf("query latest per model: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ out := make([]*service.ChannelMonitorLatest, 0)
+ for rows.Next() {
+ l := &service.ChannelMonitorLatest{}
+ var latency, ping sql.NullInt64
+ if err := rows.Scan(&l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
+ return nil, fmt.Errorf("scan latest row: %w", err)
+ }
+ assignNullInt(&l.LatencyMs, latency)
+ assignNullInt(&l.PingLatencyMs, ping)
+ out = append(out, l)
+ }
+ return out, rows.Err()
+}
+
+// assignNullInt 把 sql.NullInt64 解包到 *int 指针目标(valid 才分配新 int)。
+// 集中实现避免 latency / ping 两处重复 if latency.Valid { v := int(...) ... } 模板。
+func assignNullInt(dst **int, n sql.NullInt64) {
+ if !n.Valid {
+ return
+ }
+ v := int(n.Int64)
+ *dst = &v
+}
+
+// ComputeAvailability 计算指定窗口内每个模型的可用率与平均延迟。
+// "可用" = status IN (operational, degraded)。
+//
+// 数据来源:明细表只保留 1 天;窗口前其余天数走聚合表。
+// 明细保留 30 天(monitorHistoryRetentionDays),窗口 <= 30 天时直接扫 histories,
+// 精度到秒,避免与聚合表 UNION 带来的 UTC 日切精度损失。
+func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) {
+ if windowDays <= 0 {
+ windowDays = 7
+ }
+ const q = `
+ SELECT model,
+ COUNT(*) AS total,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
+ CASE WHEN COUNT(latency_ms) > 0
+ THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
+ ELSE NULL END AS avg_latency_ms
+ FROM channel_monitor_histories
+ WHERE monitor_id = $1
+ AND checked_at >= NOW() - ($2::int || ' days')::interval
+ GROUP BY model
+ `
+ rows, err := r.db.QueryContext(ctx, q, monitorID, windowDays)
+ if err != nil {
+ return nil, fmt.Errorf("query availability: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ out := make([]*service.ChannelMonitorAvailability, 0)
+ for rows.Next() {
+ row, err := scanAvailabilityRow(rows, windowDays)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, row)
+ }
+ return out, rows.Err()
+}
+
+// scanAvailabilityRow 把单行 (model, total, ok, avg_latency) 扫描为 ChannelMonitorAvailability。
+// 仅服务于 ComputeAvailability(4 列);批量版本因为多一列 monitor_id 直接 inline 调 finalizeAvailabilityRow。
+func scanAvailabilityRow(rows interface{ Scan(...any) error }, windowDays int) (*service.ChannelMonitorAvailability, error) {
+ row := &service.ChannelMonitorAvailability{WindowDays: windowDays}
+ var avgLatency sql.NullFloat64
+ if err := rows.Scan(&row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil {
+ return nil, fmt.Errorf("scan availability row: %w", err)
+ }
+ finalizeAvailabilityRow(row, avgLatency)
+ return row, nil
+}
+
+// finalizeAvailabilityRow 根据 OperationalChecks/TotalChecks 算出可用率,
+// 并把 sql.NullFloat64 的平均延迟解包为 *int。两处复用避免维护漂移。
+func finalizeAvailabilityRow(row *service.ChannelMonitorAvailability, avgLatency sql.NullFloat64) {
+ if row.TotalChecks > 0 {
+ row.AvailabilityPct = float64(row.OperationalChecks) * 100.0 / float64(row.TotalChecks)
+ }
+ if avgLatency.Valid {
+ v := int(avgLatency.Float64)
+ row.AvgLatencyMs = &v
+ }
+}
+
+// ListLatestForMonitorIDs 一次性查询多个监控的"每个 (monitor_id, model) 最近一条"记录。
+// 利用 PG 的 DISTINCT ON 特性,借助 (monitor_id, model, checked_at DESC) 索引可走 Index Scan。
+func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*service.ChannelMonitorLatest, error) {
+ out := make(map[int64][]*service.ChannelMonitorLatest, len(ids))
+ if len(ids) == 0 {
+ return out, nil
+ }
+ const q = `
+ SELECT DISTINCT ON (monitor_id, model)
+ monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
+ FROM channel_monitor_histories
+ WHERE monitor_id = ANY($1)
+ ORDER BY monitor_id, model, checked_at DESC
+ `
+ rows, err := r.db.QueryContext(ctx, q, pq.Array(ids))
+ if err != nil {
+ return nil, fmt.Errorf("query latest batch: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var monitorID int64
+ l := &service.ChannelMonitorLatest{}
+ var latency, ping sql.NullInt64
+ if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
+ return nil, fmt.Errorf("scan latest batch row: %w", err)
+ }
+ assignNullInt(&l.LatencyMs, latency)
+ assignNullInt(&l.PingLatencyMs, ping)
+ out[monitorID] = append(out[monitorID], l)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// ListRecentHistoryForMonitors 为多个 monitor 批量取各自"指定模型"最近 N 条历史(按 checked_at DESC,最新在前)。
+// primaryModels[monitorID] 指定该监控要过滤的模型名;monitor 不在 primaryModels 中的记录不返回。
+// 通过 CTE + unnest(两个 int8/text 数组) 构造 (monitor_id, model) 白名单,
+// 再用 ROW_NUMBER() OVER (PARTITION BY monitor_id) 取各自前 N 条。
+//
+// 返回值:map[monitorID] -> []*ChannelMonitorHistoryEntry(不含 message,减少网络开销)。
+// 空 ids / 空 primaryModels 返回空 map,不报错。
+func (r *channelMonitorRepository) ListRecentHistoryForMonitors(
+ ctx context.Context,
+ ids []int64,
+ primaryModels map[int64]string,
+ perMonitorLimit int,
+) (map[int64][]*service.ChannelMonitorHistoryEntry, error) {
+ out := make(map[int64][]*service.ChannelMonitorHistoryEntry, len(ids))
+ pairIDs, pairModels := buildMonitorModelPairs(ids, primaryModels)
+ if len(pairIDs) == 0 {
+ return out, nil
+ }
+ perMonitorLimit = clampTimelineLimit(perMonitorLimit)
+
+ const q = `
+ WITH targets AS (
+ SELECT unnest($1::bigint[]) AS monitor_id,
+ unnest($2::text[]) AS model
+ ),
+ ranked AS (
+ SELECT h.monitor_id,
+ h.status,
+ h.latency_ms,
+ h.ping_latency_ms,
+ h.checked_at,
+ ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn
+ FROM channel_monitor_histories h
+ JOIN targets t
+ ON t.monitor_id = h.monitor_id AND t.model = h.model
+ )
+ SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
+ FROM ranked
+ WHERE rn <= $3
+ ORDER BY monitor_id, checked_at DESC
+ `
+ rows, err := r.db.QueryContext(ctx, q, pq.Array(pairIDs), pq.Array(pairModels), perMonitorLimit)
+ if err != nil {
+ return nil, fmt.Errorf("query recent history batch: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var monitorID int64
+ entry := &service.ChannelMonitorHistoryEntry{}
+ var latency, ping sql.NullInt64
+ if err := rows.Scan(&monitorID, &entry.Status, &latency, &ping, &entry.CheckedAt); err != nil {
+ return nil, fmt.Errorf("scan recent history row: %w", err)
+ }
+ assignNullInt(&entry.LatencyMs, latency)
+ assignNullInt(&entry.PingLatencyMs, ping)
+ out[monitorID] = append(out[monitorID], entry)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// buildMonitorModelPairs 基于 ids 过滤出有效的 (monitor_id, model) 对,model 为空时跳过。
+// 保证两个数组长度一致且一一对应,供 unnest 展开。
+func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) {
+ if len(ids) == 0 || len(primaryModels) == 0 {
+ return nil, nil
+ }
+ pairIDs := make([]int64, 0, len(ids))
+ pairModels := make([]string, 0, len(ids))
+ for _, id := range ids {
+ model, ok := primaryModels[id]
+ if !ok || strings.TrimSpace(model) == "" {
+ continue
+ }
+ pairIDs = append(pairIDs, id)
+ pairModels = append(pairModels, model)
+ }
+ return pairIDs, pairModels
+}
+
+// timelineLimit* 批量 timeline 查询的 perMonitorLimit 夹紧范围。
+// 下限 1 表示至少返回最近一条;上限 200 控制单次响应体与 SQL 内存占用(ROW_NUMBER 窗口上限)。
+const (
+ timelineLimitMin = 1
+ timelineLimitMax = 200
+)
+
+// clampTimelineLimit 把 perMonitorLimit 夹紧到 [timelineLimitMin, timelineLimitMax],避免非法值或超大查询。
+func clampTimelineLimit(n int) int {
+ if n < timelineLimitMin {
+ return timelineLimitMin
+ }
+ if n > timelineLimitMax {
+ return timelineLimitMax
+ }
+ return n
+}
+
+// ComputeAvailabilityForMonitors 一次性计算多个监控在某个窗口内的每模型可用率与平均延迟。
+// 明细保留 30 天,直接扫 histories(窗口 <= 30 天时无需聚合)。
+func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) {
+ out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids))
+ if len(ids) == 0 {
+ return out, nil
+ }
+ if windowDays <= 0 {
+ windowDays = 7
+ }
+ const q = `
+ SELECT monitor_id,
+ model,
+ COUNT(*) AS total,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
+ CASE WHEN COUNT(latency_ms) > 0
+ THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
+ ELSE NULL END AS avg_latency_ms
+ FROM channel_monitor_histories
+ WHERE monitor_id = ANY($1)
+ AND checked_at >= NOW() - ($2::int || ' days')::interval
+ GROUP BY monitor_id, model
+ `
+ rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), windowDays)
+ if err != nil {
+ return nil, fmt.Errorf("query availability batch: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var monitorID int64
+ row := &service.ChannelMonitorAvailability{WindowDays: windowDays}
+ var avgLatency sql.NullFloat64
+ if err := rows.Scan(&monitorID, &row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil {
+ return nil, fmt.Errorf("scan availability batch row: %w", err)
+ }
+ // 批量查询多了首列 monitor_id;其余字段的可用率/平均延迟换算与单 monitor 版本一致,
+ // 抽出 finalizeAvailabilityRow 复用,避免两处分别维护除法与 NullFloat 解包。
+ finalizeAvailabilityRow(row, avgLatency)
+ out[monitorID] = append(out[monitorID], row)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// ---------- 聚合维护 ----------
+
+// UpsertDailyRollupsFor 把 targetDate 当天([targetDate, targetDate+1d))的明细
+// 按 (monitor_id, model, bucket_date) 聚合写入 channel_monitor_daily_rollups。
+// - 用 ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE 实现幂等回填,
+// 重复执行只会用最新统计覆盖;
+// - $1::date 让 PG 自动把入参 truncate 到 UTC 日期,调用方不需要预处理 targetDate。
+func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error) {
+ const q = `
+ INSERT INTO channel_monitor_daily_rollups (
+ monitor_id, model, bucket_date,
+ total_checks, ok_count,
+ operational_count, degraded_count, failed_count, error_count,
+ sum_latency_ms, count_latency,
+ sum_ping_latency_ms, count_ping_latency,
+ computed_at
+ )
+ SELECT
+ monitor_id,
+ model,
+ $1::date AS bucket_date,
+ COUNT(*) AS total_checks,
+ COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
+ COUNT(*) FILTER (WHERE status = 'operational') AS operational_count,
+ COUNT(*) FILTER (WHERE status = 'degraded') AS degraded_count,
+ COUNT(*) FILTER (WHERE status = 'failed') AS failed_count,
+ COUNT(*) FILTER (WHERE status = 'error') AS error_count,
+ COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
+ COUNT(latency_ms) AS count_latency,
+ COALESCE(SUM(ping_latency_ms) FILTER (WHERE ping_latency_ms IS NOT NULL), 0) AS sum_ping_latency_ms,
+ COUNT(ping_latency_ms) AS count_ping_latency,
+ NOW()
+ FROM channel_monitor_histories
+ WHERE checked_at >= $1::date
+ AND checked_at < ($1::date + INTERVAL '1 day')
+ GROUP BY monitor_id, model
+ ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET
+ total_checks = EXCLUDED.total_checks,
+ ok_count = EXCLUDED.ok_count,
+ operational_count = EXCLUDED.operational_count,
+ degraded_count = EXCLUDED.degraded_count,
+ failed_count = EXCLUDED.failed_count,
+ error_count = EXCLUDED.error_count,
+ sum_latency_ms = EXCLUDED.sum_latency_ms,
+ count_latency = EXCLUDED.count_latency,
+ sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms,
+ count_ping_latency = EXCLUDED.count_ping_latency,
+ computed_at = NOW()
+ `
+ res, err := r.db.ExecContext(ctx, q, targetDate)
+ if err != nil {
+ return 0, fmt.Errorf("upsert daily rollups for %s: %w", targetDate.Format("2006-01-02"), err)
+ }
+ n, err := res.RowsAffected()
+ if err != nil {
+ return 0, fmt.Errorf("rows affected (upsert rollups): %w", err)
+ }
+ return n, nil
+}
+
+// DeleteRollupsBefore 物理删 bucket_date < beforeDate 的聚合行,同样分批。
+func (r *channelMonitorRepository) DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error) {
+ return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneRollupSQL, beforeDate)
+}
+
+// channelMonitorPruneBatchSize 单批删除上限。与 ops_cleanup_service 保持一致的 5000,
+// 在大表上按 id 小批删可以避免长事务和 WAL 堆积。
+const channelMonitorPruneBatchSize = 5000
+
+// channelMonitorPruneHistorySQL 分批物理删明细表过期行。
+const channelMonitorPruneHistorySQL = `
+WITH batch AS (
+ SELECT id FROM channel_monitor_histories
+ WHERE checked_at < $1
+ ORDER BY id
+ LIMIT $2
+)
+DELETE FROM channel_monitor_histories
+WHERE id IN (SELECT id FROM batch)
+`
+
+// channelMonitorPruneRollupSQL 分批物理删 rollup 表过期行。bucket_date 需要 ::date 转型
+// 保证与 DATE 列一致比较。
+const channelMonitorPruneRollupSQL = `
+WITH batch AS (
+ SELECT id FROM channel_monitor_daily_rollups
+ WHERE bucket_date < $1::date
+ ORDER BY id
+ LIMIT $2
+)
+DELETE FROM channel_monitor_daily_rollups
+WHERE id IN (SELECT id FROM batch)
+`
+
+// deleteChannelMonitorBatched 循环执行分批 DELETE,直到影响行为 0。返回累计删除行数。
+// cutoff 由调用方按列类型传入(明细用 time.Time 对 TIMESTAMPTZ,rollup 用 time.Time SQL 侧 ::date 转型)。
+func deleteChannelMonitorBatched(ctx context.Context, db *sql.DB, query string, cutoff time.Time) (int64, error) {
+ var total int64
+ for {
+ res, err := db.ExecContext(ctx, query, cutoff, channelMonitorPruneBatchSize)
+ if err != nil {
+ return total, fmt.Errorf("channel_monitor prune batch: %w", err)
+ }
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return total, fmt.Errorf("channel_monitor prune rows affected: %w", err)
+ }
+ total += affected
+ if affected == 0 {
+ break
+ }
+ }
+ return total, nil
+}
+
+// LoadAggregationWatermark 读 watermark 表(id=1)。
+// watermark 表不是 ent schema(只有一行),直接走原生 SQL。
+// - 行不存在或 last_aggregated_date IS NULL:返回 (nil, nil),由调用方决定首次回填策略
+func (r *channelMonitorRepository) LoadAggregationWatermark(ctx context.Context) (*time.Time, error) {
+ const q = `SELECT last_aggregated_date FROM channel_monitor_aggregation_watermark WHERE id = 1`
+ var t sql.NullTime
+ if err := r.db.QueryRowContext(ctx, q).Scan(&t); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("load aggregation watermark: %w", err)
+ }
+ if !t.Valid {
+ return nil, nil
+ }
+ return &t.Time, nil
+}
+
+// UpdateAggregationWatermark 更新 watermark(UPSERT 到 id=1)。
+// $1::date 让 PG 把入参 truncate 到 UTC 日期,与 last_aggregated_date 列的 DATE 类型一致。
+func (r *channelMonitorRepository) UpdateAggregationWatermark(ctx context.Context, date time.Time) error {
+ const q = `
+ INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
+ VALUES (1, $1::date, NOW())
+ ON CONFLICT (id) DO UPDATE SET
+ last_aggregated_date = EXCLUDED.last_aggregated_date,
+ updated_at = NOW()
+ `
+ if _, err := r.db.ExecContext(ctx, q, date); err != nil {
+ return fmt.Errorf("update aggregation watermark: %w", err)
+ }
+ return nil
+}
+
+// ---------- helpers ----------
+
+func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor {
+ if row == nil {
+ return nil
+ }
+ extras := row.ExtraModels
+ if extras == nil {
+ extras = []string{}
+ }
+ headers := row.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ out := &service.ChannelMonitor{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Endpoint: row.Endpoint,
+ APIKey: row.APIKeyEncrypted, // 仍为密文,service 层负责解密
+ PrimaryModel: row.PrimaryModel,
+ ExtraModels: extras,
+ GroupName: row.GroupName,
+ Enabled: row.Enabled,
+ IntervalSeconds: row.IntervalSeconds,
+ LastCheckedAt: row.LastCheckedAt,
+ CreatedBy: row.CreatedBy,
+ CreatedAt: row.CreatedAt,
+ UpdatedAt: row.UpdatedAt,
+ ExtraHeaders: headers,
+ BodyOverrideMode: row.BodyOverrideMode,
+ BodyOverride: row.BodyOverride,
+ }
+ if row.TemplateID != nil {
+ id := *row.TemplateID
+ out.TemplateID = &id
+ }
+ return out
+}
+
+// emptyHeadersIfNilRepo 与 service.emptyHeadersIfNil 功能一致,
+// repo 独立一份避免 import 循环。
+func emptyHeadersIfNilRepo(h map[string]string) map[string]string {
+ if h == nil {
+ return map[string]string{}
+ }
+ return h
+}
+
+// defaultBodyModeRepo 空串归一为 off(同上不循环)。
+func defaultBodyModeRepo(mode string) string {
+ if mode == "" {
+ return "off"
+ }
+ return mode
+}
+
+func emptySliceIfNil(in []string) []string {
+ if in == nil {
+ return []string{}
+ }
+ return in
+}
diff --git a/backend/internal/repository/channel_monitor_template_repo.go b/backend/internal/repository/channel_monitor_template_repo.go
new file mode 100644
index 00000000..845d186b
--- /dev/null
+++ b/backend/internal/repository/channel_monitor_template_repo.go
@@ -0,0 +1,195 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// channelMonitorRequestTemplateRepository 实现 service.ChannelMonitorRequestTemplateRepository。
+// 与 channelMonitorRepository 分开一个文件,职责清晰。
+type channelMonitorRequestTemplateRepository struct {
+ client *dbent.Client
+ db *sql.DB
+}
+
+// NewChannelMonitorRequestTemplateRepository 创建模板仓储实例。
+func NewChannelMonitorRequestTemplateRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRequestTemplateRepository {
+ return &channelMonitorRequestTemplateRepository{client: client, db: db}
+}
+
+// ---------- CRUD ----------
+
+func (r *channelMonitorRequestTemplateRepository) Create(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
+ client := clientFromContext(ctx, r.client)
+ builder := client.ChannelMonitorRequestTemplate.Create().
+ SetName(t.Name).
+ SetProvider(channelmonitorrequesttemplate.Provider(t.Provider)).
+ SetDescription(t.Description).
+ SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
+ if t.BodyOverride != nil {
+ builder = builder.SetBodyOverride(t.BodyOverride)
+ }
+
+ created, err := builder.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ t.ID = created.ID
+ t.CreatedAt = created.CreatedAt
+ t.UpdatedAt = created.UpdatedAt
+ return nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitorRequestTemplate, error) {
+ row, err := r.client.ChannelMonitorRequestTemplate.Query().
+ Where(channelmonitorrequesttemplate.IDEQ(id)).
+ Only(ctx)
+ if err != nil {
+ return nil, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ return entToServiceTemplate(row), nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) Update(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
+ client := clientFromContext(ctx, r.client)
+ updater := client.ChannelMonitorRequestTemplate.UpdateOneID(t.ID).
+ SetName(t.Name).
+ SetDescription(t.Description).
+ SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
+ if t.BodyOverride != nil {
+ updater = updater.SetBodyOverride(t.BodyOverride)
+ } else {
+ updater = updater.ClearBodyOverride()
+ }
+ updated, err := updater.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ t.UpdatedAt = updated.UpdatedAt
+ return nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) Delete(ctx context.Context, id int64) error {
+ client := clientFromContext(ctx, r.client)
+ if err := client.ChannelMonitorRequestTemplate.DeleteOneID(id).Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ return nil
+}
+
+func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, params service.ChannelMonitorRequestTemplateListParams) ([]*service.ChannelMonitorRequestTemplate, error) {
+ q := r.client.ChannelMonitorRequestTemplate.Query()
+ if params.Provider != "" {
+ q = q.Where(channelmonitorrequesttemplate.ProviderEQ(channelmonitorrequesttemplate.Provider(params.Provider)))
+ }
+ rows, err := q.
+ Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldName)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list monitor templates: %w", err)
+ }
+ out := make([]*service.ChannelMonitorRequestTemplate, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, entToServiceTemplate(row))
+ }
+ return out, nil
+}
+
+// ApplyToMonitors 把模板当前配置覆盖到 monitorIDs 列表里的关联监控。
+// WHERE 双重过滤:template_id = id AND id IN (monitorIDs),防止用户传了未关联本模板的 id
+// 就被覆盖。走 ent UpdateMany 保留 hooks。
+func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) {
+ if len(monitorIDs) == 0 {
+ return 0, nil
+ }
+ client := clientFromContext(ctx, r.client)
+ tpl, err := client.ChannelMonitorRequestTemplate.Query().
+ Where(channelmonitorrequesttemplate.IDEQ(id)).
+ Only(ctx)
+ if err != nil {
+ return 0, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+
+ updater := client.ChannelMonitor.Update().
+ Where(
+ channelmonitor.TemplateIDEQ(id),
+ channelmonitor.IDIn(monitorIDs...),
+ ).
+ SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode))
+ if tpl.BodyOverride != nil {
+ updater = updater.SetBodyOverride(tpl.BodyOverride)
+ } else {
+ updater = updater.ClearBodyOverride()
+ }
+
+ affected, err := updater.Save(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("apply template to monitors: %w", err)
+ }
+ return int64(affected), nil
+}
+
+// CountAssociatedMonitors 统计关联监控数(UI 展示「N 个配置」用)。
+func (r *channelMonitorRequestTemplateRepository) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) {
+ count, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.TemplateIDEQ(id)).
+ Count(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("count monitors for template %d: %w", id, err)
+ }
+ return int64(count), nil
+}
+
+// ListAssociatedMonitors 列出模板关联的所有监控简略字段。
+// ORDER BY name 稳定输出方便前端展示。
+func (r *channelMonitorRequestTemplateRepository) ListAssociatedMonitors(ctx context.Context, id int64) ([]*service.AssociatedMonitorBrief, error) {
+ rows, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.TemplateIDEQ(id)).
+ Order(dbent.Asc(channelmonitor.FieldName)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list associated monitors for template %d: %w", id, err)
+ }
+ out := make([]*service.AssociatedMonitorBrief, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, &service.AssociatedMonitorBrief{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Enabled: row.Enabled,
+ })
+ }
+ return out, nil
+}
+
+// ---------- helpers ----------
+
+func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.ChannelMonitorRequestTemplate {
+ if row == nil {
+ return nil
+ }
+ headers := row.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ return &service.ChannelMonitorRequestTemplate{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Description: row.Description,
+ ExtraHeaders: headers,
+ BodyOverrideMode: row.BodyOverrideMode,
+ BodyOverride: row.BodyOverride,
+ CreatedAt: row.CreatedAt,
+ UpdatedAt: row.UpdatedAt,
+ }
+}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index c17e3365..5e16475a 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
- SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
+ SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
+ SetRpmLimit(groupIn.RPMLimit)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel).
- SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
+ SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
+ SetRpmLimit(groupIn.RPMLimit)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {
diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go
index 9cf3b392..6dbb9fbd 100644
--- a/backend/internal/repository/migrations_runner.go
+++ b/backend/internal/repository/migrations_runner.go
@@ -51,28 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
+const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
+const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
type migrationChecksumCompatibilityRule struct {
fileChecksum string
acceptedDBChecksum map[string]struct{}
+ acceptedChecksums map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
-// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
+// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行,
+// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
- "054_drop_legacy_cache_columns.sql": {
- fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
- acceptedDBChecksum: map[string]struct{}{
- "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
- },
- },
- "061_add_usage_log_request_type.sql": {
- fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
- acceptedDBChecksum: map[string]struct{}{
- "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
- "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
- },
- },
+ "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
+ "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
+ "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
+ "110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"),
+ "112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"),
+ "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
+ "116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
+ "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"),
+ "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
+ "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
@@ -199,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if nonTx {
+ if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
+ return fmt.Errorf("prepare migration %s: %w", name, err)
+ }
+
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
@@ -248,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil
}
+func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
+ switch name {
+ case paymentOrdersOutTradeNoUniqueMigration:
+ return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
+ default:
+ return nil
+ }
+}
+
+func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
+ duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
+ if err != nil {
+ return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
+ }
+ if len(duplicates) > 0 {
+ return fmt.Errorf(
+ "duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
+ paymentOrdersOutTradeNoUniqueMigration,
+ strings.Join(duplicates, ", "),
+ )
+ }
+
+ invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
+ if err != nil {
+ return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
+ }
+ if !invalid {
+ return nil
+ }
+
+ if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
+ return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
+ }
+ return nil
+}
+
+func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
+ rows, err := db.QueryContext(ctx, `
+ SELECT out_trade_no, COUNT(*) AS duplicate_count
+ FROM payment_orders
+ WHERE out_trade_no <> ''
+ GROUP BY out_trade_no
+ HAVING COUNT(*) > 1
+ ORDER BY duplicate_count DESC, out_trade_no
+ LIMIT 5
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ _ = rows.Close()
+ }()
+
+ duplicates := make([]string, 0, 5)
+ for rows.Next() {
+ var outTradeNo string
+ var duplicateCount int
+ if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
+ return nil, err
+ }
+ duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return duplicates, nil
+}
+
+func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
+ var invalid bool
+ err := db.QueryRowContext(ctx, `
+ SELECT EXISTS (
+ SELECT 1
+ FROM pg_class idx
+ JOIN pg_namespace ns ON ns.oid = idx.relnamespace
+ JOIN pg_index i ON i.indexrelid = idx.oid
+ WHERE ns.nspname = 'public'
+ AND idx.relname = $1
+ AND NOT i.indisvalid
+ )
+ `, indexName).Scan(&invalid)
+ return invalid, err
+}
+
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
@@ -322,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
+func checksumSet(values ...string) map[string]struct{} {
+ out := make(map[string]struct{}, len(values))
+ for _, value := range values {
+ out[value] = struct{}{}
+ }
+ return out
+}
+
+func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule {
+ return migrationChecksumCompatibilityRule{
+ fileChecksum: fileChecksum,
+ acceptedDBChecksum: checksumSet(acceptedDBChecksums...),
+ acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...),
+ }
+}
+
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
- if rule.fileChecksum != fileChecksum {
+ _, dbOK := rule.acceptedChecksums[dbChecksum]
+ if !dbOK {
return false
}
- _, ok = rule.acceptedDBChecksum[dbChecksum]
- return ok
+ _, fileOK := rule.acceptedChecksums[fileChecksum]
+ return fileOK
}
func validateMigrationExecutionMode(name, content string) (bool, error) {
diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go
index 6c3ad725..1fcb3be1 100644
--- a/backend/internal/repository/migrations_runner_checksum_test.go
+++ b/backend/internal/repository/migrations_runner_checksum_test.go
@@ -51,4 +51,114 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
)
require.False(t, ok)
})
+
+ t.Run("109历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("110历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "110_pending_auth_and_provider_default_grants.sql",
+ "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925",
+ "32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("112历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "112_add_payment_order_provider_key_snapshot.sql",
+ "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e",
+ "b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "115_auth_identity_legacy_external_backfill.sql",
+ "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
+ "022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "116_auth_identity_legacy_external_safety_reports.sql",
+ "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
+ "07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "119_enforce_payment_orders_out_trade_no_unique.sql",
+ "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
+ "0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) {
+ for _, dbChecksum := range []string{
+ "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb",
+ "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227",
+ } {
+ ok := isMigrationChecksumCompatible(
+ "118_wechat_dual_mode_and_auth_source_defaults.sql",
+ dbChecksum,
+ "b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0",
+ )
+ require.True(t, ok)
+ }
+ })
+
+ t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
+ for _, dbChecksum := range []string{
+ "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
+ "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
+ "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
+ } {
+ ok := isMigrationChecksumCompatible(
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
+ dbChecksum,
+ "34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
+ )
+ require.True(t, ok)
+ }
+ })
+
+ t.Run("119未知checksum不兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "119_enforce_payment_orders_out_trade_no_unique.sql",
+ "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+ )
+ require.False(t, ok)
+ })
}
diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go
index 9f8a94c6..5d67665e 100644
--- a/backend/internal/repository/migrations_runner_extra_test.go
+++ b/backend/internal/repository/migrations_runner_extra_test.go
@@ -94,6 +94,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
+func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
+ for _, name := range []string{
+ "109_auth_identity_compat_backfill.sql",
+ "110_pending_auth_and_provider_default_grants.sql",
+ "112_add_payment_order_provider_key_snapshot.sql",
+ "115_auth_identity_legacy_external_backfill.sql",
+ "116_auth_identity_legacy_external_safety_reports.sql",
+ "118_wechat_dual_mode_and_auth_source_defaults.sql",
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
+ "123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
+ } {
+ rule, ok := migrationChecksumCompatibilityRules[name]
+ require.Truef(t, ok, "missing compatibility rule for %s", name)
+ require.NotEmpty(t, rule.fileChecksum)
+ require.NotEmpty(t, rule.acceptedDBChecksum)
+ }
+}
+
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go
index db1183cd..b7cb396c 100644
--- a/backend/internal/repository/migrations_runner_notx_test.go
+++ b/backend/internal/repository/migrations_runner_notx_test.go
@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require.NoError(t, mock.ExpectationsWereMet())
}
+func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "duplicate out_trade_no")
+ require.Contains(t, err.Error(), "dup-out-trade-no")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("paymentorder_out_trade_no_unique").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go
index dd3019bb..eeee5c23 100644
--- a/backend/internal/repository/migrations_schema_integration_test.go
+++ b/backend/internal/repository/migrations_schema_integration_test.go
@@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
+func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
+ tx := testTx(t)
+
+ requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
+ requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
+ requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
+ requireConstraintDefinitionContains(
+ t,
+ tx,
+ "users",
+ "users_signup_source_check",
+ "signup_source",
+ "'email'",
+ "'linuxdo'",
+ "'wechat'",
+ "'oidc'",
+ )
+
+ requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL")
+ requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL")
+
+ requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
+ requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
+ requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
+}
+
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
@@ -106,6 +135,118 @@ SELECT EXISTS (
require.True(t, exists, "expected index %s on %s", index, table)
}
+func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
+ t.Helper()
+
+ var exists bool
+ err := tx.QueryRowContext(context.Background(), `
+SELECT EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = $1
+ AND indexname = $2
+)
+`, table, index).Scan(&exists)
+ require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
+ require.False(t, exists, "expected index %s on %s to be absent", index, table)
+}
+
+func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
+ t.Helper()
+
+ var (
+ unique bool
+ def string
+ )
+
+ err := tx.QueryRowContext(context.Background(), `
+SELECT
+ i.indisunique,
+ pg_get_indexdef(i.indexrelid)
+FROM pg_class idx
+JOIN pg_index i ON i.indexrelid = idx.oid
+JOIN pg_class tbl ON tbl.oid = i.indrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE ns.nspname = 'public'
+ AND tbl.relname = $1
+ AND idx.relname = $2
+`, table, index).Scan(&unique, &def)
+ require.NoError(t, err, "query index definition for %s.%s", table, index)
+ require.True(t, unique, "expected index %s on %s to be unique", index, table)
+
+ for _, fragment := range fragments {
+ require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
+ }
+}
+
+func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
+ t.Helper()
+
+ var actual string
+ err := tx.QueryRowContext(context.Background(), `
+SELECT CASE c.confdeltype
+ WHEN 'a' THEN 'NO ACTION'
+ WHEN 'r' THEN 'RESTRICT'
+ WHEN 'c' THEN 'CASCADE'
+ WHEN 'n' THEN 'SET NULL'
+ WHEN 'd' THEN 'SET DEFAULT'
+END
+FROM pg_constraint c
+JOIN pg_class tbl ON tbl.oid = c.conrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
+JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
+WHERE ns.nspname = 'public'
+ AND c.contype = 'f'
+ AND tbl.relname = $1
+ AND attr.attname = $2
+ AND ref_tbl.relname = $3
+LIMIT 1
+`, table, column, refTable).Scan(&actual)
+ require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable)
+ require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
+}
+
+func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
+ t.Helper()
+
+ var def string
+ err := tx.QueryRowContext(context.Background(), `
+SELECT pg_get_constraintdef(c.oid)
+FROM pg_constraint c
+JOIN pg_class tbl ON tbl.oid = c.conrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE ns.nspname = 'public'
+ AND tbl.relname = $1
+ AND c.conname = $2
+`, table, constraint).Scan(&def)
+ require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
+
+ for _, fragment := range fragments {
+ require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
+ }
+}
+
+func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
+ t.Helper()
+
+ var columnDefault sql.NullString
+ err := tx.QueryRowContext(context.Background(), `
+SELECT column_default
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = $1
+ AND column_name = $2
+`, table, column).Scan(&columnDefault)
+ require.NoError(t, err, "query column_default for %s.%s", table, column)
+ require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
+
+ for _, fragment := range fragments {
+ require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
+ }
+}
+
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
diff --git a/backend/internal/repository/openai_403_counter_cache.go b/backend/internal/repository/openai_403_counter_cache.go
new file mode 100644
index 00000000..a68d2518
--- /dev/null
+++ b/backend/internal/repository/openai_403_counter_cache.go
@@ -0,0 +1,51 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+const openAI403CounterPrefix = "openai_403_count:account:"
+
+var openAI403CounterIncrScript = redis.NewScript(`
+ local key = KEYS[1]
+ local ttl = tonumber(ARGV[1])
+
+ local count = redis.call('INCR', key)
+ if count == 1 then
+ redis.call('EXPIRE', key, ttl)
+ end
+
+ return count
+`)
+
+type openAI403CounterCache struct {
+ rdb *redis.Client
+}
+
+func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
+ return &openAI403CounterCache{rdb: rdb}
+}
+
+func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
+ key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
+
+ ttlSeconds := windowMinutes * 60
+ if ttlSeconds < 60 {
+ ttlSeconds = 60
+ }
+
+ result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
+ if err != nil {
+ return 0, fmt.Errorf("increment openai 403 count: %w", err)
+ }
+ return result, nil
+}
+
+func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
+ key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
+ return c.rdb.Del(ctx, key).Err()
+}
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index dca0b612..acb270a3 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -2,6 +2,7 @@ package repository
import (
"context"
+ "errors"
"net/http"
"net/url"
"strings"
@@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
Post(s.tokenURL)
if err != nil {
+ if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
+ return nil, newOpenAINoProxyHintError(err)
+ }
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
@@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
Post(s.tokenURL)
if err != nil {
+ if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
+ return nil, newOpenAINoProxyHintError(err)
+ }
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
@@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
Timeout: 120 * time.Second,
})
}
+
+func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool {
+ if strings.TrimSpace(proxyURL) != "" || err == nil {
+ return false
+ }
+ if ctx != nil && ctx.Err() != nil {
+ return false
+ }
+ return !errors.Is(err, context.Canceled)
+}
+
+func newOpenAINoProxyHintError(cause error) error {
+ return infraerrors.New(
+ http.StatusBadGateway,
+ "OPENAI_OAUTH_PROXY_REQUIRED",
+ "OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.",
+ ).WithCause(cause)
+}
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index c1901d71..b43e2b52 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -8,6 +8,7 @@ import (
"net/url"
"testing"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
require.ErrorContains(s.T(), err, "request failed")
}
+func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() {
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+ s.srv.Close()
+
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
+
+ require.Error(s.T(), err)
+ require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err))
+ require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured")
+}
+
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go
index add0e501..590ddaa3 100644
--- a/backend/internal/repository/scheduler_cache.go
+++ b/backend/internal/repository/scheduler_cache.go
@@ -24,6 +24,49 @@ const (
defaultSchedulerSnapshotMGetChunkSize = 128
defaultSchedulerSnapshotWriteChunkSize = 256
+
+ // snapshotGraceTTLSeconds 旧快照过期的宽限期(秒)。
+ // 替代立即 DEL,让正在读取旧版本的 reader 有足够时间完成 ZRANGE。
+ snapshotGraceTTLSeconds = 60
+)
+
+var (
+ // activateSnapshotScript 原子 CAS 切换快照版本。
+ // 仅当新版本号 >= 当前激活版本时才切换,防止并发写入导致版本回滚。
+ // 旧快照使用 EXPIRE 设置宽限期而非立即 DEL,避免与 reader 竞态。
+ //
+ // KEYS[1] = activeKey (sched:active:{bucket})
+ // KEYS[2] = readyKey (sched:ready:{bucket})
+ // KEYS[3] = bucketSetKey (sched:buckets)
+ // KEYS[4] = snapshotKey (新写入的快照 key)
+ // ARGV[1] = 新版本号字符串
+ // ARGV[2] = bucket 字符串 (用于 SADD)
+ // ARGV[3] = 快照 key 前缀 (用于构造旧快照 key)
+ // ARGV[4] = 宽限期 TTL 秒数
+ //
+ // 返回 1 = 已激活, 0 = 版本过旧未激活
+ activateSnapshotScript = redis.NewScript(`
+local currentActive = redis.call('GET', KEYS[1])
+local newVersion = tonumber(ARGV[1])
+
+if currentActive ~= false then
+ local curVersion = tonumber(currentActive)
+ if curVersion and newVersion < curVersion then
+ redis.call('DEL', KEYS[4])
+ return 0
+ end
+end
+
+redis.call('SET', KEYS[1], ARGV[1])
+redis.call('SET', KEYS[2], '1')
+redis.call('SADD', KEYS[3], ARGV[2])
+
+if currentActive ~= false and currentActive ~= ARGV[1] then
+ redis.call('EXPIRE', ARGV[3] .. currentActive, tonumber(ARGV[4]))
+end
+
+return 1
+`)
)
type schedulerCache struct {
@@ -108,9 +151,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
}
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
- activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
- oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
-
+ // Phase 1: 分配新版本号并写入快照数据。
+ // INCR 保证每个调用方获得唯一递增版本号。
+ // 写入的 snapshotKey 是新的版本化 key,reader 尚不知晓,因此无竞态。
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
version, err := c.rdb.Incr(ctx, versionKey).Result()
if err != nil {
@@ -124,7 +167,6 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
return err
}
- pipe := c.rdb.Pipeline()
if len(accounts) > 0 {
// 使用序号作为 score,保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts))
@@ -134,6 +176,7 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member: strconv.FormatInt(account.ID, 10),
})
}
+ pipe := c.rdb.Pipeline()
for start := 0; start < len(members); start += c.writeChunkSize {
end := start + c.writeChunkSize
if end > len(members) {
@@ -141,18 +184,25 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
}
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
}
- } else {
- pipe.Del(ctx, snapshotKey)
- }
- pipe.Set(ctx, activeKey, versionStr, 0)
- pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
- pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
- if _, err := pipe.Exec(ctx); err != nil {
- return err
+ if _, err := pipe.Exec(ctx); err != nil {
+ return err
+ }
}
- if oldActive != "" && oldActive != versionStr {
- _ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
+ // Phase 2: 原子 CAS 激活版本。
+ // Lua 脚本保证:仅当新版本 >= 当前激活版本时才切换 active 指针,
+ // 防止并发写入导致版本回滚。
+ // 旧快照使用 EXPIRE 宽限期而非立即 DEL,避免 reader 竞态。
+ activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
+ readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
+ snapshotKeyPrefix := fmt.Sprintf("%s%d:%s:%s:v", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode)
+
+ keys := []string{activeKey, readyKey, schedulerBucketSetKey, snapshotKey}
+ args := []any{versionStr, bucket.String(), snapshotKeyPrefix, snapshotGraceTTLSeconds}
+
+ _, err = activateSnapshotScript.Run(ctx, c.rdb, keys, args...).Result()
+ if err != nil {
+ return err
}
return nil
@@ -232,6 +282,11 @@ func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.Sched
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
}
+func (c *schedulerCache) UnlockBucket(ctx context.Context, bucket service.SchedulerBucket) error {
+ key := schedulerBucketKey(schedulerLockPrefix, bucket)
+ return c.rdb.Del(ctx, key).Err()
+}
+
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
if err != nil {
@@ -394,11 +449,69 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
SessionWindowStart: account.SessionWindowStart,
SessionWindowEnd: account.SessionWindowEnd,
SessionWindowStatus: account.SessionWindowStatus,
+ AccountGroups: filterSchedulerAccountGroups(account.AccountGroups),
+ GroupIDs: filterSchedulerGroupIDs(account.GroupIDs, account.AccountGroups),
Credentials: filterSchedulerCredentials(account.Credentials),
Extra: filterSchedulerExtra(account.Extra),
}
}
+func filterSchedulerAccountGroups(accountGroups []service.AccountGroup) []service.AccountGroup {
+ if len(accountGroups) == 0 {
+ return nil
+ }
+
+ filtered := make([]service.AccountGroup, 0, len(accountGroups))
+ for _, ag := range accountGroups {
+ if ag.GroupID <= 0 {
+ continue
+ }
+ filtered = append(filtered, service.AccountGroup{
+ AccountID: ag.AccountID,
+ GroupID: ag.GroupID,
+ Priority: ag.Priority,
+ CreatedAt: ag.CreatedAt,
+ })
+ }
+ if len(filtered) == 0 {
+ return nil
+ }
+ return filtered
+}
+
+func filterSchedulerGroupIDs(groupIDs []int64, accountGroups []service.AccountGroup) []int64 {
+ if len(groupIDs) == 0 && len(accountGroups) == 0 {
+ return nil
+ }
+
+ seen := make(map[int64]struct{}, len(groupIDs)+len(accountGroups))
+ filtered := make([]int64, 0, len(groupIDs)+len(accountGroups))
+ for _, id := range groupIDs {
+ if id <= 0 {
+ continue
+ }
+ if _, ok := seen[id]; ok {
+ continue
+ }
+ seen[id] = struct{}{}
+ filtered = append(filtered, id)
+ }
+ for _, ag := range accountGroups {
+ if ag.GroupID <= 0 {
+ continue
+ }
+ if _, ok := seen[ag.GroupID]; ok {
+ continue
+ }
+ seen[ag.GroupID] = struct{}{}
+ filtered = append(filtered, ag.GroupID)
+ }
+ if len(filtered) == 0 {
+ return nil
+ }
+ return filtered
+}
+
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
if len(credentials) == 0 {
return nil
diff --git a/backend/internal/repository/scheduler_cache_integration_test.go b/backend/internal/repository/scheduler_cache_integration_test.go
index 134a6a07..948c2c73 100644
--- a/backend/internal/repository/scheduler_cache_integration_test.go
+++ b/backend/internal/repository/scheduler_cache_integration_test.go
@@ -56,6 +56,15 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
SessionWindowStart: &now,
SessionWindowEnd: &windowEnd,
SessionWindowStatus: "active",
+ GroupIDs: []int64{bucket.GroupID},
+ AccountGroups: []service.AccountGroup{
+ {
+ AccountID: 101,
+ GroupID: bucket.GroupID,
+ Priority: 5,
+ Group: &service.Group{ID: bucket.GroupID, Name: "gemini-group"},
+ },
+ },
}
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
@@ -79,10 +88,17 @@ func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T)
require.Equal(t, 4, got.GetMaxSessions())
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
require.Nil(t, got.Extra["unused_large_field"])
+ require.Equal(t, []int64{bucket.GroupID}, got.GroupIDs)
+ require.Len(t, got.AccountGroups, 1)
+ require.Equal(t, account.ID, got.AccountGroups[0].AccountID)
+ require.Equal(t, bucket.GroupID, got.AccountGroups[0].GroupID)
+ require.Nil(t, got.AccountGroups[0].Group)
full, err := cache.GetAccount(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, full)
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
+ require.Len(t, full.AccountGroups, 1)
+ require.NotNil(t, full.AccountGroups[0].Group)
}
diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go
index bcfd0e7a..33f3b581 100644
--- a/backend/internal/repository/scheduler_cache_unit_test.go
+++ b/backend/internal/repository/scheduler_cache_unit_test.go
@@ -31,3 +31,43 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
require.Equal(t, true, got.Extra["mixed_scheduling"])
require.Nil(t, got.Extra["unused_large_field"])
}
+
+func TestBuildSchedulerMetadataAccount_KeepsSlimGroupMembership(t *testing.T) {
+ account := service.Account{
+ ID: 42,
+ Platform: service.PlatformAnthropic,
+ GroupIDs: []int64{7, 9, 7, 0},
+ AccountGroups: []service.AccountGroup{
+ {
+ AccountID: 42,
+ GroupID: 7,
+ Priority: 2,
+ Account: &service.Account{ID: 42, Name: "drop-from-metadata"},
+ Group: &service.Group{ID: 7, Name: "drop-from-metadata"},
+ },
+ {
+ AccountID: 42,
+ GroupID: 11,
+ Priority: 3,
+ Group: &service.Group{ID: 11, Name: "drop-from-metadata"},
+ },
+ {
+ AccountID: 42,
+ GroupID: 0,
+ Priority: 4,
+ },
+ },
+ }
+
+ got := buildSchedulerMetadataAccount(account)
+
+ require.Equal(t, []int64{7, 9, 11}, got.GroupIDs)
+ require.Len(t, got.AccountGroups, 2)
+ require.Equal(t, int64(42), got.AccountGroups[0].AccountID)
+ require.Equal(t, int64(7), got.AccountGroups[0].GroupID)
+ require.Equal(t, 2, got.AccountGroups[0].Priority)
+ require.Nil(t, got.AccountGroups[0].Account)
+ require.Nil(t, got.AccountGroups[0].Group)
+ require.Equal(t, int64(11), got.AccountGroups[1].GroupID)
+ require.Nil(t, got.Groups)
+}
diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go
index 2b6edad3..62f48b58 100644
--- a/backend/internal/repository/usage_billing_repo.go
+++ b/backend/internal/repository/usage_billing_repo.go
@@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
if err != nil {
return nil, err
}
- defer func() { _ = rows.Close() }()
var state service.AccountQuotaState
if rows.Next() {
@@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
&state.DailyUsed, &state.DailyLimit,
&state.WeeklyUsed, &state.WeeklyLimit,
); err != nil {
+ _ = rows.Close()
return nil, err
}
} else {
if err := rows.Err(); err != nil {
+ _ = rows.Close()
return nil, err
}
+ _ = rows.Close()
return nil, service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
+ _ = rows.Close()
return nil, err
}
- if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
+ // 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
+ // 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
+ // "unexpected Parse response" 错误。
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ // 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
+ // 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
+ // 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
+ // 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
+ // 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
+ crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
+ crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
+ crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
+ if crossedTotal || crossedDaily || crossedWeekly {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return nil, err
diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go
index eda34cc9..e8d4d327 100644
--- a/backend/internal/repository/usage_billing_repo_integration_test.go
+++ b/backend/internal/repository/usage_billing_repo_integration_test.go
@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
+func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) {
+ ctx := context.Background()
+ client := testEntClient(t)
+ repo := NewUsageBillingRepository(client, integrationDB)
+
+ newFixture := func(t *testing.T, extra map[string]any) (int64, int64) {
+ t.Helper()
+ user := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()),
+ PasswordHash: "hash",
+ })
+ apiKey := mustCreateApiKey(t, client, &service.APIKey{
+ UserID: user.ID,
+ Key: "sk-usage-billing-outbox-" + uuid.NewString(),
+ Name: "billing-outbox",
+ })
+ account := mustCreateAccount(t, client, &service.Account{
+ Name: "usage-billing-outbox-" + uuid.NewString(),
+ Type: service.AccountTypeAPIKey,
+ Extra: extra,
+ })
+ return apiKey.ID, account.ID
+ }
+
+ outboxCountFor := func(t *testing.T, accountID int64) int {
+ t.Helper()
+ var count int
+ require.NoError(t, integrationDB.QueryRowContext(ctx,
+ "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2",
+ service.SchedulerOutboxEventAccountChanged, accountID,
+ ).Scan(&count))
+ return count
+ }
+
+ t.Run("daily_first_crossing_enqueues", func(t *testing.T) {
+ apiKeyID, accountID := newFixture(t, map[string]any{
+ "quota_daily_limit": 10.0,
+ })
+ // 第一次低于日限额:不应入队 outbox
+ _, err := repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 4,
+ })
+ require.NoError(t, err)
+ require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue")
+
+ // 第二次跨越日限额:应入队一次 outbox
+ _, err = repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 8,
+ })
+ require.NoError(t, err)
+ require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once")
+
+ // 再次递增(已超):不应重复入队
+ _, err = repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 2,
+ })
+ require.NoError(t, err)
+ require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue")
+ })
+
+ t.Run("weekly_first_crossing_enqueues", func(t *testing.T) {
+ apiKeyID, accountID := newFixture(t, map[string]any{
+ "quota_weekly_limit": 10.0,
+ })
+ _, err := repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 15, // 单次即跨越
+ })
+ require.NoError(t, err)
+ require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once")
+ })
+}
+
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go
index eca5313f..74d25cb0 100644
--- a/backend/internal/repository/user_group_rate_repo.go
+++ b/backend/internal/repository/user_group_rate_repo.go
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql sqlExecutor
}
-// NewUserGroupRateRepository 创建用户专属分组倍率仓储
+// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
-// GetByUserID 获取用户的所有专属分组倍率
+// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
- query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
+ query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
-// GetByUserIDs 批量获取多个用户的专属分组倍率。
-// 返回结构:map[userID]map[groupID]rate
+// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 {
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
- WHERE user_id = ANY($1)
+ WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
@@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil
}
-// GetByGroupID 获取指定分组下所有用户的专属倍率
+// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
- SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
+ SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
- if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
+ var rate sql.NullFloat64
+ var rpm sql.NullInt32
+ if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
return nil, err
}
+ if rate.Valid {
+ v := rate.Float64
+ entry.RateMultiplier = &v
+ }
+ if rpm.Valid {
+ v := int(rpm.Int32)
+ entry.RPMOverride = &v
+ }
result = append(result, entry)
}
if err := rows.Err(); err != nil {
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return result, nil
}
-// GetByUserAndGroup 获取用户在特定分组的专属倍率
+// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
- var rate float64
+ var rate sql.NullFloat64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if err != nil {
return nil, err
}
- return &rate, nil
+ if !rate.Valid {
+ return nil, nil
+ }
+ v := rate.Float64
+ return &v, nil
}
-// SyncUserGroupRates 同步用户的分组专属倍率
+// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
+func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
+ query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
+ var rpm sql.NullInt32
+ err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ if !rpm.Valid {
+ return nil, nil
+ }
+ v := int(rpm.Int32)
+ return &v, nil
+}
+
+// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
+// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
+// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
+// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
- // 如果传入空 map,删除该用户的所有专属倍率
- _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE user_id = $1
+ `, userID); err != nil {
+ return err
+ }
+ _, err := r.sql.ExecContext(ctx,
+ `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
+ userID)
return err
}
- // 分离需要删除和需要 upsert 的记录
- var toDelete []int64
+ var clearGroupIDs []int64
upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
- toDelete = append(toDelete, groupID)
+ clearGroupIDs = append(clearGroupIDs, groupID)
} else {
upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
}
}
- // 删除指定的记录
- if len(toDelete) > 0 {
+ if len(clearGroupIDs) > 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE user_id = $1 AND group_id = ANY($2)
+ `, userID, pq.Array(clearGroupIDs)); err != nil {
+ return err
+ }
if _, err := r.sql.ExecContext(ctx,
- `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
- userID, pq.Array(toDelete)); err != nil {
+ `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
+ userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
}
- // Upsert 记录
- now := time.Now()
if len(upsertGroupIDs) > 0 {
+ now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
-// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
+// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
+// 语义:
+// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
+// - 出现的用户行:upsert rate_multiplier。
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
- if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
+ keepUserIDs := make([]int64, 0, len(entries))
+ for _, e := range entries {
+ keepUserIDs = append(keepUserIDs, e.UserID)
+ }
+
+ // 未在 entries 列表中的行:清空 rate_multiplier。
+ if len(keepUserIDs) == 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE group_id = $1
+ `, groupID); err != nil {
+ return err
+ }
+ } else {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE group_id = $1 AND user_id <> ALL($2)
+ `, groupID, pq.Array(keepUserIDs)); err != nil {
+ return err
+ }
+ }
+
+ // 清空后若整行 NULL 则删除。
+ if _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM user_group_rate_multipliers
+ WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
+ `, groupID); err != nil {
return err
}
+
if len(entries) == 0 {
return nil
}
+
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return err
}
-// DeleteByGroupID 删除指定分组的所有用户专属倍率
+// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
+// 语义:
+// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
+// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
+func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
+ keepUserIDs := make([]int64, 0, len(entries))
+ var clearUserIDs []int64
+ upsertUserIDs := make([]int64, 0, len(entries))
+ upsertValues := make([]int32, 0, len(entries))
+ for _, e := range entries {
+ keepUserIDs = append(keepUserIDs, e.UserID)
+ if e.RPMOverride == nil {
+ clearUserIDs = append(clearUserIDs, e.UserID)
+ } else {
+ upsertUserIDs = append(upsertUserIDs, e.UserID)
+ upsertValues = append(upsertValues, int32(*e.RPMOverride))
+ }
+ }
+
+ // 未在 entries 列表中的行:清空 rpm_override。
+ if len(keepUserIDs) == 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1
+ `, groupID); err != nil {
+ return err
+ }
+ } else {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1 AND user_id <> ALL($2)
+ `, groupID, pq.Array(keepUserIDs)); err != nil {
+ return err
+ }
+ }
+
+ // 显式 clear 的行。
+ if len(clearUserIDs) > 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1 AND user_id = ANY($2)
+ `, groupID, pq.Array(clearUserIDs)); err != nil {
+ return err
+ }
+ }
+
+ // 清空后若整行 NULL 则删除。
+ if _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM user_group_rate_multipliers
+ WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
+ `, groupID); err != nil {
+ return err
+ }
+
+ if len(upsertUserIDs) > 0 {
+ now := time.Now()
+ _, err := r.sql.ExecContext(ctx, `
+ INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
+ SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
+ FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
+ ON CONFLICT (user_id, group_id)
+ DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
+ `, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
+func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1
+ `, groupID); err != nil {
+ return err
+ }
+ _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM user_group_rate_multipliers
+ WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
+ `, groupID)
+ return err
+}
+
+// DeleteByGroupID 删除指定分组的所有用户专属条目
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
-// DeleteByUserID 删除指定用户的所有专属倍率
+// DeleteByUserID 删除指定用户的所有专属条目
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go
new file mode 100644
index 00000000..b2b03746
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo.go
@@ -0,0 +1,880 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "hash/fnv"
+ "reflect"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+ "unsafe"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+var (
+ ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_OWNERSHIP_CONFLICT",
+ "auth identity already belongs to another user",
+ )
+ ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
+ "auth identity channel already belongs to another user",
+ )
+ ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
+ "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
+ "auth identity channel provider must match canonical identity",
+ )
+)
+
+type ProviderGrantReason string
+
+const (
+ ProviderGrantReasonSignup ProviderGrantReason = "signup"
+ ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
+)
+
+type AuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type AuthIdentityChannelKey struct {
+ ProviderType string
+ ProviderKey string
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+}
+
+type CreateAuthIdentityInput struct {
+ UserID int64
+ Canonical AuthIdentityKey
+ Channel *AuthIdentityChannelKey
+ Issuer *string
+ VerifiedAt *time.Time
+ Metadata map[string]any
+ ChannelMetadata map[string]any
+}
+
+type BindAuthIdentityInput = CreateAuthIdentityInput
+
+type CreateAuthIdentityResult struct {
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
+ if r == nil || r.Identity == nil {
+ return AuthIdentityKey{}
+ }
+ return AuthIdentityKey{
+ ProviderType: r.Identity.ProviderType,
+ ProviderKey: r.Identity.ProviderKey,
+ ProviderSubject: r.Identity.ProviderSubject,
+ }
+}
+
+func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
+ if r == nil || r.Channel == nil {
+ return nil
+ }
+ return &AuthIdentityChannelKey{
+ ProviderType: r.Channel.ProviderType,
+ ProviderKey: r.Channel.ProviderKey,
+ Channel: r.Channel.Channel,
+ ChannelAppID: r.Channel.ChannelAppID,
+ ChannelSubject: r.Channel.ChannelSubject,
+ }
+}
+
+type UserAuthIdentityLookup struct {
+ User *dbent.User
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+type ProviderGrantRecordInput struct {
+ UserID int64
+ ProviderType string
+ GrantReason ProviderGrantReason
+}
+
+type IdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type sqlQueryExecutor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+}
+
+var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
+
+type scopedKeyLockRegistry struct {
+ mu sync.Mutex
+ locks map[string]*scopedKeyLockEntry
+}
+
+type scopedKeyLockEntry struct {
+ mu sync.Mutex
+ refs int
+}
+
+func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
+ return &scopedKeyLockRegistry{
+ locks: make(map[string]*scopedKeyLockEntry),
+ }
+}
+
+func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
+ normalized := normalizeLockKeys(keys...)
+ if len(normalized) == 0 {
+ return func() {}
+ }
+
+ entries := make([]*scopedKeyLockEntry, 0, len(normalized))
+ r.mu.Lock()
+ for _, key := range normalized {
+ entry := r.locks[key]
+ if entry == nil {
+ entry = &scopedKeyLockEntry{}
+ r.locks[key] = entry
+ }
+ entry.refs++
+ entries = append(entries, entry)
+ }
+ r.mu.Unlock()
+
+ for _, entry := range entries {
+ entry.mu.Lock()
+ }
+
+ return func() {
+ for i := len(entries) - 1; i >= 0; i-- {
+ entries[i].mu.Unlock()
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for idx, key := range normalized {
+ entry := entries[idx]
+ entry.refs--
+ if entry.refs == 0 {
+ delete(r.locks, key)
+ }
+ }
+ }
+}
+
+func normalizeLockKeys(keys ...string) []string {
+ if len(keys) == 0 {
+ return nil
+ }
+
+ deduped := make(map[string]struct{}, len(keys))
+ for _, key := range keys {
+ trimmed := strings.TrimSpace(key)
+ if trimmed == "" {
+ continue
+ }
+ deduped[trimmed] = struct{}{}
+ }
+ if len(deduped) == 0 {
+ return nil
+ }
+
+ normalized := make([]string, 0, len(deduped))
+ for key := range deduped {
+ normalized = append(normalized, key)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+func advisoryLockHash(key string) int64 {
+ hasher := fnv.New64a()
+ _, _ = hasher.Write([]byte(key))
+ return int64(hasher.Sum64())
+}
+
+func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
+ release := repositoryScopedKeyLocks.lock(keys...)
+ normalized := normalizeLockKeys(keys...)
+ if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
+ return release, nil
+ }
+
+ for _, key := range normalized {
+ rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
+ if err != nil {
+ release()
+ return nil, err
+ }
+ _ = rows.Close()
+ }
+ return release, nil
+}
+
+func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if dbent.TxFromContext(ctx) != nil {
+ return fn(ctx)
+ }
+
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
+ client := clientFromContext(ctx, r.client)
+
+ create := client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt)
+
+ identity, err := create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
+}
+
+func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
+ identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
+ ).
+ WithUser().
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: identity.Edges.User,
+ Identity: identity,
+ }, nil
+}
+
+func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
+ channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: channel.Edges.Identity.Edges.User,
+ Identity: channel.Edges.Identity,
+ Channel: channel,
+ }, nil
+}
+
+func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(userID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ records := make([]service.UserAuthIdentityRecord, 0, len(identities))
+ for _, identity := range identities {
+ if identity == nil {
+ continue
+ }
+ records = append(records, service.UserAuthIdentityRecord{
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: copyMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ })
+ }
+
+ return records, nil
+}
+
+func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
+ provider = strings.ToLower(strings.TrimSpace(provider))
+ if provider == "" || provider == "email" {
+ return service.ErrIdentityProviderInvalid
+ }
+
+ return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ identityIDs, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ(provider),
+ ).
+ IDs(txCtx)
+ if err != nil {
+ return err
+ }
+ if len(identityIDs) == 0 {
+ return nil
+ }
+
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return err
+ }
+ if _, err := client.AuthIdentityChannel.Delete().
+ Where(authidentitychannel.IdentityIDIn(identityIDs...)).
+ Exec(txCtx); err != nil {
+ return err
+ }
+ _, err = client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ(provider),
+ ).
+ Exec(txCtx)
+ return err
+ })
+}
+
+func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
+ var result *CreateAuthIdentityResult
+ err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ canonical := input.Canonical
+
+ identityRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
+ authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
+ ).
+ All(txCtx)
+ if err != nil {
+ return err
+ }
+ identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
+ if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ if identity == nil {
+ identity, err = client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
+ update := client.AuthIdentity.UpdateOneID(identity.ID)
+ if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
+ update = update.SetProviderKey(targetProviderKey)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(copyMetadata(input.Metadata))
+ }
+ if input.Issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
+ }
+ if input.VerifiedAt != nil {
+ update = update.SetVerifiedAt(*input.VerifiedAt)
+ }
+ identity, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channelRecords, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
+ authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
+ ).
+ WithIdentity().
+ All(txCtx)
+ if err != nil {
+ return err
+ }
+ channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
+ if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
+ return ErrAuthIdentityChannelOwnershipConflict
+ }
+ if channel == nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
+ update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID)
+ if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
+ update = update.SetProviderKey(targetProviderKey)
+ }
+ if input.ChannelMetadata != nil {
+ update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
+ }
+ channel, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return []string{providerKey}
+ }
+ if providerType != "wechat" {
+ return []string{providerKey}
+ }
+ keys := []string{providerKey}
+ if !strings.EqualFold(providerKey, "wechat-main") {
+ keys = append(keys, "wechat-main")
+ }
+ if !strings.EqualFold(providerKey, "wechat") {
+ keys = append(keys, "wechat")
+ }
+ return keys
+}
+
+func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ existingKey = strings.TrimSpace(existingKey)
+ requestedKey = strings.TrimSpace(requestedKey)
+ if providerType != "wechat" {
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+ }
+ if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
+ return "wechat-main"
+ }
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+}
+
+func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerType != "wechat" {
+ return 0
+ }
+ switch {
+ case strings.EqualFold(providerKey, "wechat-main"):
+ return 0
+ case strings.EqualFold(providerKey, "wechat"):
+ return 2
+ default:
+ return 1
+ }
+}
+
+func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
+ var selected *dbent.AuthIdentity
+ for _, record := range records {
+ if record.UserID != userID {
+ continue
+ }
+ if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
+ for _, record := range records {
+ if record.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
+ var selected *dbent.AuthIdentityChannel
+ for _, record := range records {
+ if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
+ continue
+ }
+ if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
+ for _, record := range records {
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return false, fmt.Errorf("sql executor is not configured")
+ }
+
+ result, err := exec.ExecContext(ctx, `
+INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+ input.UserID,
+ strings.TrimSpace(input.ProviderType),
+ string(input.GrantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ var result *dbent.IdentityAdoptionDecision
+ err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ releaseLocks, err := lockRepositoryScopedKeys(
+ txCtx,
+ client,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseLocks()
+
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
+ dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
+ col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.NEQ(col, input.PendingAuthSessionID),
+ ))
+ }),
+ ).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return err
+ }
+ }
+
+ create := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+
+ decisionID, err := create.
+ OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
+ UpdateNewValues().
+ ID(txCtx)
+ if err != nil {
+ return err
+ }
+
+ result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
+ keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
+ if identityID != nil && *identityID > 0 {
+ keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
+ }
+ return keys
+}
+
+func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
+ return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
+ Only(ctx)
+}
+
+func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
+ _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+ SetLastLoginAt(loginAt).
+ Save(ctx)
+ return err
+}
+
+func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+ SetLastActiveAt(activeAt).
+ Save(ctx)
+ return err
+}
+
+func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := exec.QueryContext(ctx, `
+SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
+FROM user_avatars
+WHERE user_id = $1`, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ return nil, rows.Err()
+ }
+
+ var avatar service.UserAvatar
+ if err := rows.Scan(
+ &avatar.StorageProvider,
+ &avatar.StorageKey,
+ &avatar.URL,
+ &avatar.ContentType,
+ &avatar.ByteSize,
+ &avatar.SHA256,
+ ); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return &avatar, nil
+}
+
+func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = exec.ExecContext(ctx, `
+INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
+ON CONFLICT (user_id) DO UPDATE SET
+ storage_provider = EXCLUDED.storage_provider,
+ storage_key = EXCLUDED.storage_key,
+ url = EXCLUDED.url,
+ content_type = EXCLUDED.content_type,
+ byte_size = EXCLUDED.byte_size,
+ sha256 = EXCLUDED.sha256,
+ updated_at = NOW()`,
+ userID,
+ strings.TrimSpace(input.StorageProvider),
+ strings.TrimSpace(input.StorageKey),
+ strings.TrimSpace(input.URL),
+ strings.TrimSpace(input.ContentType),
+ input.ByteSize,
+ strings.TrimSpace(input.SHA256),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return &service.UserAvatar{
+ StorageProvider: strings.TrimSpace(input.StorageProvider),
+ StorageKey: strings.TrimSpace(input.StorageKey),
+ URL: strings.TrimSpace(input.URL),
+ ContentType: strings.TrimSpace(input.ContentType),
+ ByteSize: input.ByteSize,
+ SHA256: strings.TrimSpace(input.SHA256),
+ }, nil
+}
+
+func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ exec, err := r.userProfileIdentitySQL(ctx)
+ if err != nil {
+ return err
+ }
+ _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
+ return err
+}
+
+func copyMetadata(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
+ if channel == nil {
+ return nil
+ }
+
+ canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
+ canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
+ channelProviderType := strings.TrimSpace(channel.ProviderType)
+ channelProviderKey := strings.TrimSpace(channel.ProviderKey)
+
+ if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
+ return ErrAuthIdentityChannelProviderMismatch
+ }
+
+ return nil
+}
+
+func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
+ return exec
+ }
+ }
+ if fallback != nil {
+ return fallback
+ }
+ return sqlExecutorFromEntClient(client)
+}
+
+func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
+ exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+ if exec == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+ return exec, nil
+}
+
+func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
+ if client == nil {
+ return nil
+ }
+
+ clientValue := reflect.ValueOf(client).Elem()
+ configValue := clientValue.FieldByName("config")
+ driverValue := configValue.FieldByName("driver")
+ if !driverValue.IsValid() {
+ return nil
+ }
+
+ driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
+ exec, ok := driver.(sqlQueryExecutor)
+ if !ok {
+ return nil
+ }
+ return exec
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
new file mode 100644
index 00000000..d4f9e8b3
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -0,0 +1,578 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/suite"
+)
+
+type UserProfileIdentityRepoSuite struct {
+ suite.Suite
+ ctx context.Context
+ client *dbent.Client
+ repo *userRepository
+}
+
+func TestUserProfileIdentityRepoSuite(t *testing.T) {
+ suite.Run(t, new(UserProfileIdentityRepoSuite))
+}
+
+func (s *UserProfileIdentityRepoSuite) SetupTest() {
+ s.ctx = context.Background()
+ s.client = testEntClient(s.T())
+ s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
+
+ _, err := integrationDB.ExecContext(s.ctx, `
+TRUNCATE TABLE
+ identity_adoption_decisions,
+ auth_identity_channels,
+ auth_identities,
+ pending_auth_sessions,
+ user_provider_default_grants,
+ user_avatars
+RESTART IDENTITY`)
+ s.Require().NoError(err)
+}
+
+func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
+ s.T().Helper()
+
+ user, err := s.client.User.Create().
+ SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
+ SetPasswordHash("test-password-hash").
+ SetRole("user").
+ SetStatus("active").
+ Save(s.ctx)
+ s.Require().NoError(err)
+ return user
+}
+
+func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
+ s.T().Helper()
+
+ session, err := s.client.PendingAuthSession.Create().
+ SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
+ SetIntent("bind_current_user").
+ SetProviderType(key.ProviderType).
+ SetProviderKey(key.ProviderKey).
+ SetProviderSubject(key.ProviderSubject).
+ SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+ SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
+ SetLocalFlowState(map[string]any{"step": "pending"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+ return session
+}
+
+func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
+ user := s.mustCreateUser("canonical-channel")
+
+ verifiedAt := time.Now().UTC().Truncate(time.Second)
+ created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ Channel: "mp",
+ ChannelAppID: "wx-app",
+ ChannelSubject: "openid-123",
+ },
+ Issuer: stringPtr("https://issuer.example"),
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{"unionid": "union-123"},
+ ChannelMetadata: map[string]any{"openid": "openid-123"},
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(created.Identity)
+ s.Require().NotNil(created.Channel)
+
+ canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, canonical.User.ID)
+ s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
+ s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
+
+ channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, channel.User.ID)
+ s.Require().Equal(created.Identity.ID, channel.Identity.ID)
+ s.Require().Equal(created.Channel.ID, channel.Channel.ID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
+ owner := s.mustCreateUser("owner")
+ other := s.mustCreateUser("other")
+
+ first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: owner.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ Metadata: map[string]any{"username": "first"},
+ ChannelMetadata: map[string]any{"scope": "read"},
+ })
+ s.Require().NoError(err)
+
+ second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: owner.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ Metadata: map[string]any{"username": "second"},
+ ChannelMetadata: map[string]any{"scope": "write"},
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(first.Identity.ID, second.Identity.ID)
+ s.Require().Equal(first.Channel.ID, second.Channel.ID)
+ s.Require().Equal("second", second.Identity.Metadata["username"])
+ s.Require().Equal("write", second.Channel.Metadata["scope"])
+
+ _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: other.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
+
+ _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: other.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-2",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "linuxdo-web",
+ ChannelSubject: "subject-1",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() {
+ user := s.mustCreateUser("wechat-legacy-alias")
+
+ legacyIdentity, err := s.client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy-alias"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ legacyChannel, err := s.client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("oa").
+ SetChannelAppID("wx-app-legacy").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy-alias"}).
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ Channel: "oa",
+ ChannelAppID: "wx-app-legacy",
+ ChannelSubject: "openid-legacy-123",
+ },
+ Metadata: map[string]any{"source": "canonical-bind"},
+ ChannelMetadata: map[string]any{"scene": "canonical-bind"},
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(bound)
+ s.Require().NotNil(bound.Identity)
+ s.Require().NotNil(bound.Channel)
+ s.Require().Equal(legacyIdentity.ID, bound.Identity.ID)
+ s.Require().Equal(legacyChannel.ID, bound.Channel.ID)
+ s.Require().Equal("wechat-main", bound.Identity.ProviderKey)
+ s.Require().Equal("wechat-main", bound.Channel.ProviderKey)
+ s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"])
+ s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"])
+
+ identityCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, identityCount)
+
+ channelCount, err := s.client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("oa"),
+ authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, channelCount)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
+ user := s.mustCreateUser("provider-mismatch-create")
+
+ _, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-create-mismatch",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ Channel: "oauth",
+ ChannelAppID: "app-mismatch",
+ ChannelSubject: "openid-create-mismatch",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
+ user := s.mustCreateUser("provider-mismatch-bind")
+
+ _, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-bind-mismatch",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-legacy",
+ Channel: "oa",
+ ChannelAppID: "wx-app-bind-mismatch",
+ ChannelSubject: "openid-bind-mismatch",
+ },
+ })
+ s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
+ user := s.mustCreateUser("tx-rollback")
+ expectedErr := errors.New("rollback")
+
+ err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+ _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-rollback",
+ },
+ })
+ s.Require().NoError(err)
+
+ inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "oidc",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+ return expectedErr
+ })
+ s.Require().ErrorIs(err, expectedErr)
+
+ _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-rollback",
+ })
+ s.Require().True(dbent.IsNotFound(err))
+
+ var count int
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
+ user.ID,
+ "oidc",
+ string(ProviderGrantReasonFirstBind),
+ ).Scan(&count))
+ s.Require().Zero(count)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
+ user := s.mustCreateUser("grant")
+
+ inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+
+ inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonFirstBind,
+ })
+ s.Require().NoError(err)
+ s.Require().False(inserted)
+
+ inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+ UserID: user.ID,
+ ProviderType: "wechat",
+ GrantReason: ProviderGrantReasonSignup,
+ })
+ s.Require().NoError(err)
+ s.Require().True(inserted)
+
+ var count int
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2`,
+ user.ID,
+ "wechat",
+ ).Scan(&count))
+ s.Require().Equal(2, count)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
+ user := s.mustCreateUser("adoption")
+ identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ s.Require().NoError(err)
+
+ session := s.mustCreatePendingAuthSession(identity.IdentityRef())
+
+ first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ s.Require().NoError(err)
+ s.Require().True(first.AdoptDisplayName)
+ s.Require().False(first.AdoptAvatar)
+ s.Require().Nil(first.IdentityID)
+
+ second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ s.Require().NoError(err)
+ s.Require().Equal(first.ID, second.ID)
+ s.Require().NotNil(second.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *second.IdentityID)
+ s.Require().True(second.AdoptAvatar)
+
+ loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(second.ID, loaded.ID)
+ s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() {
+ user := s.mustCreateUser("adoption-reassign")
+ identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption-reassign",
+ },
+ })
+ s.Require().NoError(err)
+
+ firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
+ firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: firstSession.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(firstDecision.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID)
+
+ secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
+ secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: secondSession.ID,
+ IdentityID: &identity.Identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ s.Require().NoError(err)
+ s.Require().NotNil(secondDecision.IdentityID)
+ s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID)
+
+ reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID)
+ s.Require().NoError(err)
+ s.Require().Nil(reloadedFirst.IdentityID)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
+ user := s.mustCreateUser("avatar-only-update")
+
+ model, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(model)
+
+ err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+ _, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/avatar.png",
+ })
+ if err != nil {
+ return err
+ }
+ return s.repo.Update(txCtx, model)
+ })
+ s.Require().NoError(err)
+
+ avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(avatar)
+ s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
+ user := s.mustCreateUser("avatar")
+
+ inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "inline",
+ URL: "data:image/png;base64,QUJD",
+ ContentType: "image/png",
+ ByteSize: 3,
+ SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
+ })
+ s.Require().NoError(err)
+ s.Require().Equal("inline", inlineAvatar.StorageProvider)
+ s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
+
+ loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(loadedAvatar)
+ s.Require().Equal("image/png", loadedAvatar.ContentType)
+ s.Require().Equal(3, loadedAvatar.ByteSize)
+
+ _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/avatar.png",
+ })
+ s.Require().NoError(err)
+
+ loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(loadedAvatar)
+ s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
+ s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
+ s.Require().Zero(loadedAvatar.ByteSize)
+
+ s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
+ loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().Nil(loadedAvatar)
+}
+
+func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
+ user := s.mustCreateUser("activity")
+ loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ activeAt := loginAt.Add(5 * time.Minute)
+
+ s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
+ s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
+
+ var storedLoginAt sqlNullTime
+ var storedActiveAt sqlNullTime
+ s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT last_login_at, last_active_at
+FROM users
+WHERE id = $1`,
+ user.ID,
+ ).Scan(&storedLoginAt, &storedActiveAt))
+ s.Require().True(storedLoginAt.Valid)
+ s.Require().True(storedActiveAt.Valid)
+ s.Require().True(storedLoginAt.Time.Equal(loginAt))
+ s.Require().True(storedActiveAt.Time.Equal(activeAt))
+}
+
+type sqlNullTime struct {
+ Time time.Time
+ Valid bool
+}
+
+func (t *sqlNullTime) Scan(value any) error {
+ switch v := value.(type) {
+ case time.Time:
+ t.Time = v
+ t.Valid = true
+ return nil
+ case nil:
+ t.Time = time.Time{}
+ t.Valid = false
+ return nil
+ default:
+ return fmt.Errorf("unsupported scan type %T", value)
+ }
+}
+
+func stringPtr(v string) *string {
+ return &v
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_unit_test.go b/backend/internal/repository/user_profile_identity_repo_unit_test.go
new file mode 100644
index 00000000..689f32f9
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_unit_test.go
@@ -0,0 +1,212 @@
+package repository
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "wechat-legacy@example.com",
+ Username: "wechat-legacy",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, user))
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy-alias"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyChannel, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("oa").
+ SetChannelAppID("wx-app-legacy").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy-alias"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
+ UserID: user.ID,
+ Canonical: AuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ },
+ Channel: &AuthIdentityChannelKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ Channel: "oa",
+ ChannelAppID: "wx-app-legacy",
+ ChannelSubject: "openid-legacy-123",
+ },
+ Metadata: map[string]any{"source": "canonical-bind"},
+ ChannelMetadata: map[string]any{"scene": "canonical-bind"},
+ })
+ require.NoError(t, err)
+ require.NotNil(t, bound)
+ require.NotNil(t, bound.Identity)
+ require.NotNil(t, bound.Channel)
+ require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
+ require.Equal(t, legacyChannel.ID, bound.Channel.ID)
+ require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
+ require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
+
+ reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
+ require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
+
+ reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
+ require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("oa"),
+ authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, channelCount)
+}
+
+func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "repo-adoption@example.com",
+ Username: "repo-adoption",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, user))
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-repo-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("pending-repo-adoption").
+ SetIntent("bind_current_user").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-repo-adoption").
+ SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+ SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
+ SetLocalFlowState(map[string]any{"step": "pending"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type adoptionResult struct {
+ decision *dbent.IdentityAdoptionDecision
+ err error
+ }
+
+ input := IdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ }
+
+ results := make(chan adoptionResult, 2)
+ go func() {
+ decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ require.NoError(t, first.err)
+ require.NoError(t, second.err)
+ require.NotNil(t, first.decision)
+ require.NotNil(t, second.decision)
+ require.Equal(t, first.decision.ID, second.decision.ID)
+
+ count, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+
+ loaded, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, loaded.IdentityID)
+ require.Equal(t, identity.ID, *loaded.IdentityID)
+ require.True(t, loaded.AdoptDisplayName)
+ require.True(t, loaded.AdoptAvatar)
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 913e1c40..d1f10cbd 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -11,12 +11,17 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
@@ -47,12 +52,33 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
+ txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
} else {
- // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
- txClient = r.client
+ // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
+ }
+
+ releaseEmailLock, err := lockRepositoryScopedKeys(
+ txCtx,
+ txClient,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ normalizedEmailUniquenessLockKey(userIn.Email),
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseEmailLock()
+
+ if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
+ return err
}
created, err := txClient.User.Create().
@@ -64,12 +90,19 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- Save(ctx)
+ SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
+ SetNillableLastLoginAt(userIn.LastLoginAt).
+ SetNillableLastActiveAt(userIn.LastActiveAt).
+ SetRpmLimit(userIn.RPMLimit).
+ Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
- if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
+ if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
+ return err
+ }
+ if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
@@ -101,10 +134,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
- m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ matches, err := r.client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
if err != nil {
- return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
+ return nil, err
}
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
+ }
+ m := matches[0]
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
@@ -129,14 +172,41 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
+ txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
} else {
- // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
- txClient = r.client
+ // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
}
+ releaseEmailLock, err := lockRepositoryScopedKeys(
+ txCtx,
+ txClient,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ normalizedEmailUniquenessLockKey(userIn.Email),
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseEmailLock()
+
+ if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
+ return err
+ }
+
+ existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ oldEmail := existing.Email
+
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
@@ -150,16 +220,29 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
- SetTotalRecharged(userIn.TotalRecharged)
+ SetTotalRecharged(userIn.TotalRecharged).
+ SetRpmLimit(userIn.RPMLimit)
+ if userIn.SignupSource != "" {
+ updateOp = updateOp.SetSignupSource(userIn.SignupSource)
+ }
+ if userIn.LastLoginAt != nil {
+ updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
+ }
+ if userIn.LastActiveAt != nil {
+ updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
+ }
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
- updated, err := updateOp.Save(ctx)
+ updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
- if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
+ if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
+ return err
+ }
+ if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
@@ -173,14 +256,146 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
+func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil || userID <= 0 {
+ return nil
+ }
+
+ subject := normalizeEmailAuthIdentitySubject(email)
+ if subject == "" {
+ return nil
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": source}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ if !isSQLNoRowsError(err) {
+ return err
+ }
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ return nil
+}
+
+func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
+ newSubject := normalizeEmailAuthIdentitySubject(newEmail)
+ if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := clientFromContext(ctx, client).AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func normalizeEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return ""
+ }
+ if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
+ return ""
+ }
+ return normalized
+}
+
func (r *userRepository) Delete(ctx context.Context, id int64) error {
- affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
+ tx, err := r.client.Tx(ctx)
+ if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+
+ var txClient *dbent.Client
+ if err == nil {
+ defer func() { _ = tx.Rollback() }()
+ txClient = tx.Client()
+ } else {
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
+ }
+
+ identityIDs, err := txClient.AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(id)).
+ IDs(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ if len(identityIDs) > 0 {
+ if _, err := txClient.IdentityAdoptionDecision.Update().
+ Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
+ ClearIdentityID().
+ Save(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ if _, err := txClient.AuthIdentityChannel.Delete().
+ Where(authidentitychannel.IdentityIDIn(identityIDs...)).
+ Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ if _, err := txClient.AuthIdentity.Delete().
+ Where(authidentity.UserIDEQ(id)).
+ Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ }
+
+ affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if affected == 0 {
return service.ErrUserNotFound
}
+
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ }
return nil
}
@@ -298,8 +513,13 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+ if sortBy == "last_used_at" {
+ return userLastUsedAtOrder(sortOrder)
+ }
+
var field string
defaultField := true
+ nullsLastField := false
switch sortBy {
case "email":
field = dbuser.FieldEmail
@@ -322,6 +542,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
+ case "last_active_at":
+ field = dbuser.FieldLastActiveAt
+ defaultField = false
+ nullsLastField = true
default:
field = dbuser.FieldID
}
@@ -330,14 +554,92 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
+ dbent.Asc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
+ dbent.Desc(dbuser.FieldID),
+ }
+ }
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
+func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ result := make(map[int64]*time.Time, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+ if r.sql == nil {
+ return nil, fmt.Errorf("sql executor is not configured")
+ }
+
+ const query = `
+ SELECT user_id, MAX(created_at) AS last_used_at
+ FROM usage_logs
+ WHERE user_id = ANY($1)
+ GROUP BY user_id
+ `
+
+ rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var (
+ userID int64
+ lastUsedAt time.Time
+ )
+ if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
+ return nil, scanErr
+ }
+ ts := lastUsedAt.UTC()
+ result[userID] = &ts
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
+func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
+ orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
+ return func(s *entsql.Selector) {
+ subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
+ s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
+ s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
+ }
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){
+ orderExpr("ASC", "FIRST", entsql.Asc),
+ }
+ }
+ return []func(*entsql.Selector){
+ orderExpr("DESC", "LAST", entsql.Desc),
+ }
+}
+
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
@@ -436,17 +738,68 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
- return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
+ return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
+}
+
+func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil {
+ return nil
+ }
+
+ matches, err := client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, match := range matches {
+ if match.ID != userID {
+ return service.ErrEmailExists
+ }
+ }
+ return nil
+}
+
+func userEmailLookupPredicate(email string) predicate.User {
+ normalized := normalizeEmailLookupValue(email)
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.P(func(b *entsql.Builder) {
+ b.WriteString("LOWER(TRIM(").
+ Ident(s.C(dbuser.FieldEmail)).
+ WriteString(")) = ").
+ Arg(normalized)
+ }))
+ })
+}
+
+func normalizeEmailLookupValue(email string) string {
+ return strings.ToLower(strings.TrimSpace(email))
+}
+
+func normalizedEmailUniquenessLockKey(email string) string {
+ normalized := normalizeEmailLookupValue(email)
+ if normalized == "" {
+ return ""
+ }
+ return "users:normalized-email:" + normalized
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
- return client.UserAllowedGroup.Create().
+ err := client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
+ if isSQLNoRowsError(err) {
+ return nil
+ }
+ return err
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
@@ -546,6 +899,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
+ if isSQLNoRowsError(err) {
+ return nil
+ }
return err
}
}
@@ -558,10 +914,24 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
+ dst.SignupSource = src.SignupSource
+ dst.LastLoginAt = src.LastLoginAt
+ dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
+func userSignupSourceOrDefault(signupSource string) string {
+ switch strings.TrimSpace(strings.ToLower(signupSource)) {
+ case "", "email":
+ return "email"
+ case "linuxdo", "wechat", "oidc":
+ return strings.TrimSpace(strings.ToLower(signupSource))
+ default:
+ return "email"
+ }
+}
+
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go
new file mode 100644
index 00000000..fddd82c5
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_identity_integration_test.go
@@ -0,0 +1,86 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
+ user := &service.User{
+ Email: "repo-create@example.com",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ identity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("repo-create@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, identity.UserID)
+}
+
+func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
+ user := &service.User{
+ Email: "linuxdo-legacy-user@linuxdo-connect.invalid",
+ PasswordHash: "test-password-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 2,
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, user))
+
+ count, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(count)
+}
+
+func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
+ user := s.mustCreateUser(&service.User{
+ Email: "before-update@example.com",
+ })
+
+ user.Email = "after-update@example.com"
+ s.Require().NoError(s.repo.Update(s.ctx, user))
+
+ newIdentity, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("after-update@example.com"),
+ ).
+ Only(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(user.ID, newIdentity.UserID)
+
+ oldCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("before-update@example.com"),
+ ).
+ Count(context.Background())
+ s.Require().NoError(err)
+ s.Require().Zero(oldCount)
+}
diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go
new file mode 100644
index 00000000..7da3db9b
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go
@@ -0,0 +1,227 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+ db.SetMaxOpenConns(10)
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return newUserRepositoryWithSQL(client, db), client
+}
+
+func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ got, err := repo.GetByEmail(ctx, "legacy@example.com")
+ require.NoError(t, err)
+ require.Equal(t, " Legacy@Example.com ", got.Email)
+}
+
+func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ")
+ require.NoError(t, err)
+ require.True(t, exists)
+}
+
+func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Existing@Example.com ",
+ Username: "existing-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ err = repo.Create(ctx, &service.User{
+ Email: "existing@example.com",
+ Username: "duplicate-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.ErrorIs(t, err, service.ErrEmailExists)
+}
+
+func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ first := &service.User{
+ Email: " Existing@Example.com ",
+ Username: "existing-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, first))
+
+ second := &service.User{
+ Email: "second@example.com",
+ Username: "second-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, second))
+
+ second.Email = " existing@example.com "
+ err := repo.Update(ctx, second)
+ require.ErrorIs(t, err, service.ErrEmailExists)
+}
+
+func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ _, err := client.User.Create().
+ SetEmail("Conflict@Example.com").
+ SetUsername("conflict-user-1").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.User.Create().
+ SetEmail(" conflict@example.com ").
+ SetUsername("conflict-user-2").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = repo.GetByEmail(ctx, "conflict@example.com")
+ require.Error(t, err)
+ require.ErrorContains(t, err, "normalized email lookup matched multiple users")
+}
+
+func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.User.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type createResult struct {
+ err error
+ }
+
+ results := make(chan createResult, 2)
+ go func() {
+ results <- createResult{err: repo.Create(ctx, &service.User{
+ Email: " Race@Example.com ",
+ Username: "race-user-1",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ results <- createResult{err: repo.Create(ctx, &service.User{
+ Email: "race@example.com",
+ Username: "race-user-2",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ errors := []error{first.err, second.err}
+ successes := 0
+ conflicts := 0
+ for _, err := range errors {
+ switch err {
+ case nil:
+ successes++
+ case service.ErrEmailExists:
+ conflicts++
+ default:
+ t.Fatalf("unexpected create error: %v", err)
+ }
+ }
+ require.Equal(t, 1, successes)
+ require.Equal(t, 1, conflicts)
+
+ count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+}
diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go
index f5d0f9ff..13a605a2 100644
--- a/backend/internal/repository/user_repo_integration_test.go
+++ b/backend/internal/repository/user_repo_integration_test.go
@@ -8,6 +8,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -26,6 +28,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
@@ -122,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() {
s.Require().Equal(user.ID, got.ID)
}
+func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() {
+ user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
+
+ got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ")
+ s.Require().NoError(err, "GetByEmail normalized lookup")
+ s.Require().Equal(user.ID, got.ID)
+}
+
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
+func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() {
+ s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
+
+ exists, err := s.repo.ExistsByEmail(s.ctx, " LEGACY@example.com ")
+ s.Require().NoError(err, "ExistsByEmail normalized lookup")
+ s.Require().True(exists)
+}
+
func (s *UserRepoSuite) TestUpdate() {
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
@@ -140,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
s.Require().Equal("updated", updated.Username)
}
+func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
+ user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
+
+ identityCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, identityCount)
+
+ got, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ got.Username = "updated"
+ s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
+
+ updated, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("updated", updated.Username)
+}
+
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
@@ -150,6 +194,39 @@ func (s *UserRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
+func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() {
+ user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"})
+
+ identity, err := s.client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("delete-oauth-subject").
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ _, err = s.client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("open").
+ SetChannelAppID("app-id").
+ SetChannelSubject("openid-123").
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ err = s.repo.Delete(s.ctx, user.ID)
+ s.Require().NoError(err)
+
+ identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(identityCount)
+
+ channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(channelCount)
+}
+
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
index ab84b0e9..3a15bc10 100644
--- a/backend/internal/repository/user_repo_sort_integration_test.go
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -4,11 +4,30 @@ package repository
import (
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
+func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) {
+ s.T().Helper()
+
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID})
+
+ _, err := integrationDB.ExecContext(
+ s.ctx,
+ `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at)
+ VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`,
+ userID,
+ apiKey.ID,
+ account.ID,
+ createdAt.UTC(),
+ )
+ s.Require().NoError(err)
+}
+
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
@@ -36,4 +55,110 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID)
}
+func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
+ lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created := s.mustCreateUser(&service.User{
+ Email: "identity-meta@example.com",
+ SignupSource: "linuxdo",
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ })
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("linuxdo", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
+ created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
+ lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created.SignupSource = "oidc"
+ created.LastLoginAt = &lastLoginAt
+ created.LastActiveAt = &lastActiveAt
+
+ s.Require().NoError(s.repo.Update(s.ctx, created))
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("oidc", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
+ earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
+ later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
+ s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
+ s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_active_at",
+ SortOrder: "asc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("earlier-active@example.com", users[0].Email)
+ s.Require().Equal("later-active@example.com", users[1].Email)
+ s.Require().Equal("nil-active@example.com", users[2].Email)
+}
+
+func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second)
+ newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second)
+
+ userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"})
+ userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"})
+ s.mustInsertUsageLog(userWithUsage.ID, older)
+ s.mustInsertUsageLog(userWithUsage.ID, newer)
+
+ got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID})
+ s.Require().NoError(err)
+ s.Require().Contains(got, userWithUsage.ID)
+ s.Require().NotContains(got, userWithoutUsage.ID)
+ s.Require().NotNil(got[userWithUsage.ID])
+ s.Require().True(got[userWithUsage.ID].Equal(newer))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() {
+ lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second)
+ lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second)
+ lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second)
+
+ nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"})
+ wrongSource := s.mustCreateUser(&service.User{
+ Email: "active-not-usage@example.com",
+ LastActiveAt: &lastActiveVeryRecent,
+ })
+ rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"})
+
+ s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder)
+ s.mustInsertUsageLog(rightSource.ID, lastUsedNewer)
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_used_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal(rightSource.ID, users[0].ID)
+ s.Require().Equal(wrongSource.ID, users[1].ID)
+ s.Require().Equal(nilUsage.ID, users[2].ID)
+}
+
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/repository/user_rpm_cache.go b/backend/internal/repository/user_rpm_cache.go
new file mode 100644
index 00000000..42bf9332
--- /dev/null
+++ b/backend/internal/repository/user_rpm_cache.go
@@ -0,0 +1,108 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// 用户/分组级 RPM 计数器 Redis 实现。
+//
+// 设计说明:
+// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
+// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
+// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
+// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
+// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
+const (
+ userGroupRPMKeyPrefix = "rpm:ug:"
+ userRPMKeyPrefix = "rpm:u:"
+
+ userRPMKeyTTL = 120 * time.Second
+)
+
+type userRPMCacheImpl struct {
+ rdb *redis.Client
+}
+
+// NewUserRPMCache 创建用户/分组级 RPM 计数器。
+func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
+ return &userRPMCacheImpl{rdb: rdb}
+}
+
+// minuteTS 获取当前 Redis 服务端分钟时间戳。
+func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
+ t, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return 0, fmt.Errorf("redis TIME: %w", err)
+ }
+ return t.Unix() / 60, nil
+}
+
+// atomicIncr 原子 INCR+EXPIRE。
+func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
+ pipe := c.rdb.TxPipeline()
+ incr := pipe.Incr(ctx, key)
+ pipe.Expire(ctx, key, userRPMKeyTTL)
+ if _, err := pipe.Exec(ctx); err != nil {
+ return 0, fmt.Errorf("user rpm increment: %w", err)
+ }
+ return int(incr.Val()), nil
+}
+
+// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
+func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
+ return c.atomicIncr(ctx, key)
+}
+
+// IncrementUserRPM 递增用户分钟计数。
+func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
+ return c.atomicIncr(ctx, key)
+}
+
+// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
+func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
+ val, err := c.rdb.Get(ctx, key).Int()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ if err != nil {
+ return 0, fmt.Errorf("user group rpm get: %w", err)
+ }
+ return val, nil
+}
+
+// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
+func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
+ val, err := c.rdb.Get(ctx, key).Int()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ if err != nil {
+ return 0, fmt.Errorf("user rpm get: %w", err)
+ }
+ return val, nil
+}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index d3adb4a0..f07bbb33 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -89,6 +89,9 @@ var ProviderSet = wire.NewSet(
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
+ NewChannelMonitorRepository,
+ NewChannelMonitorRequestTemplateRepository,
+ NewAffiliateRepository,
// Cache implementations
NewGatewayCache,
@@ -96,10 +99,12 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
+ NewOpenAI403CounterCache,
NewInternal500CounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
+ NewUserRPMCache,
NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index b686b986..607b93dc 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -50,10 +50,12 @@ func TestAPIContracts(t *testing.T) {
"data": {
"id": 1,
"email": "alice@example.com",
+ "email_bound": true,
"username": "alice",
"role": "user",
"balance": 12.5,
"concurrency": 5,
+ "rpm_limit": 0,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
@@ -63,6 +65,123 @@ func TestAPIContracts(t *testing.T) {
"balance_notify_threshold": null,
"balance_notify_extra_emails": null,
"total_recharged": 0,
+ "linuxdo_bound": false,
+ "oidc_bound": false,
+ "wechat_bound": false,
+ "identities": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
+ "identity_bindings": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
+ "auth_bindings": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
"run_mode": "standard"
}
}`,
@@ -215,6 +334,7 @@ func TestAPIContracts(t *testing.T) {
"fallback_group_id_on_invalid_request": null,
"require_oauth_only": false,
"require_privacy_set": false,
+ "rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
@@ -479,7 +599,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
- service.SettingKeyOIDCConnectUsePKCE: "false",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
@@ -500,10 +620,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- service.SettingKeyOpsMonitoringEnabled: "false",
- service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
- service.SettingKeyOpsQueryModeDefault: "auto",
- service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingKeyOpsMonitoringEnabled: "false",
+ service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
+ service.SettingKeyOpsQueryModeDefault: "auto",
+ service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay,
+ service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat,
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ "openai_advanced_scheduler_enabled": "true",
})
},
method: http.MethodGet,
@@ -549,7 +674,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
- "oidc_connect_use_pkce": false,
+ "oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
@@ -567,8 +692,34 @@ func TestAPIContracts(t *testing.T) {
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
+ "affiliate_rebate_rate": 20,
+ "affiliate_rebate_freeze_hours": 0,
+ "affiliate_rebate_duration_days": 0,
+ "affiliate_rebate_per_invitee_cap": 0,
+ "default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
@@ -589,9 +740,25 @@ func TestAPIContracts(t *testing.T) {
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_cch_signing": false,
+ "enable_anthropic_cache_ttl_1h_injection": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "easypay_alipay",
+ "payment_visible_method_wxpay_source": "official_wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ "openai_fast_policy_settings": {
+ "rules": [
+ {
+ "service_tier": "priority",
+ "action": "filter",
+ "scope": "all",
+ "fallback_action": "pass"
+ }
+ ]
+ },
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
@@ -618,7 +785,239 @@ func TestAPIContracts(t *testing.T) {
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
- "account_quota_notify_emails": []
+ "account_quota_notify_emails": [],
+ "channel_monitor_enabled": true,
+ "channel_monitor_default_interval_seconds": 60,
+ "available_channels_enabled": false,
+ "affiliate_enabled": false,
+ "wechat_connect_enabled": false,
+ "wechat_connect_app_id": "",
+ "wechat_connect_app_secret_configured": false,
+ "wechat_connect_mode": "open",
+ "wechat_connect_open_enabled": false,
+ "wechat_connect_open_app_id": "",
+ "wechat_connect_open_app_secret_configured": false,
+ "wechat_connect_mp_enabled": false,
+ "wechat_connect_mp_app_id": "",
+ "wechat_connect_mp_app_secret_configured": false,
+ "wechat_connect_mobile_enabled": false,
+ "wechat_connect_mobile_app_id": "",
+ "wechat_connect_mobile_app_secret_configured": false,
+ "wechat_connect_redirect_url": "",
+ "wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
+ "wechat_connect_scopes": "snsapi_login"
+ }
+ }`,
+ },
+ {
+ name: "GET /api/v1/admin/settings falls back to config oauth defaults",
+ setup: func(t *testing.T, deps *contractDeps) {
+ t.Helper()
+ deps.cfg.OIDC = config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "ConfigOIDC",
+ ClientID: "oidc-config-client",
+ ClientSecret: "oidc-config-secret",
+ IssuerURL: "https://issuer.example.com",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256,ES256,PS256",
+ ClockSkewSeconds: 120,
+ }
+ deps.cfg.WeChat = config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ Mode: "open",
+ Scopes: "snsapi_login",
+ FrontendRedirectURL: "/auth/wechat/callback",
+ }
+ deps.settingRepo.SetAll(map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyEmailVerifyEnabled: "false",
+ service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ })
+ },
+ method: http.MethodGet,
+ path: "/api/v1/admin/settings",
+ wantStatus: http.StatusOK,
+ wantJSON: `{
+ "code": 0,
+ "message": "success",
+ "data": {
+ "registration_enabled": true,
+ "email_verify_enabled": false,
+ "registration_email_suffix_whitelist": [],
+ "promo_code_enabled": true,
+ "password_reset_enabled": false,
+ "frontend_url": "",
+ "invitation_code_enabled": false,
+ "totp_enabled": false,
+ "totp_encryption_key_configured": false,
+ "smtp_host": "",
+ "smtp_port": 587,
+ "smtp_username": "",
+ "smtp_password_configured": false,
+ "smtp_from_email": "",
+ "smtp_from_name": "",
+ "smtp_use_tls": false,
+ "turnstile_enabled": false,
+ "turnstile_site_key": "",
+ "turnstile_secret_key_configured": false,
+ "linuxdo_connect_enabled": false,
+ "linuxdo_connect_client_id": "",
+ "linuxdo_connect_client_secret_configured": false,
+ "linuxdo_connect_redirect_url": "",
+ "oidc_connect_enabled": true,
+ "oidc_connect_provider_name": "ConfigOIDC",
+ "oidc_connect_client_id": "oidc-config-client",
+ "oidc_connect_client_secret_configured": true,
+ "oidc_connect_issuer_url": "https://issuer.example.com",
+ "oidc_connect_discovery_url": "",
+ "oidc_connect_authorize_url": "",
+ "oidc_connect_token_url": "",
+ "oidc_connect_userinfo_url": "",
+ "oidc_connect_jwks_url": "",
+ "oidc_connect_scopes": "openid email profile",
+ "oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ "oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
+ "oidc_connect_token_auth_method": "client_secret_post",
+ "oidc_connect_use_pkce": true,
+ "oidc_connect_validate_id_token": true,
+ "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
+ "oidc_connect_clock_skew_seconds": 120,
+ "oidc_connect_require_email_verified": false,
+ "oidc_connect_userinfo_email_path": "",
+ "oidc_connect_userinfo_id_path": "",
+ "oidc_connect_userinfo_username_path": "",
+ "site_name": "Sub2API",
+ "site_logo": "",
+ "site_subtitle": "Subscription to API Conversion Platform",
+ "api_base_url": "",
+ "contact_info": "",
+ "doc_url": "",
+ "home_content": "",
+ "hide_ccs_import_button": false,
+ "purchase_subscription_enabled": false,
+ "purchase_subscription_url": "",
+ "table_default_page_size": 20,
+ "table_page_size_options": [10, 20, 50],
+ "custom_menu_items": [],
+ "custom_endpoints": [],
+ "default_concurrency": 0,
+ "default_balance": 0,
+ "affiliate_rebate_rate": 20,
+ "affiliate_rebate_freeze_hours": 0,
+ "affiliate_rebate_duration_days": 0,
+ "affiliate_rebate_per_invitee_cap": 0,
+ "default_user_rpm_limit": 0,
+ "default_subscriptions": [],
+ "enable_model_fallback": false,
+ "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
+ "fallback_model_openai": "gpt-4o",
+ "fallback_model_gemini": "gemini-2.5-pro",
+ "fallback_model_antigravity": "gemini-2.5-pro",
+ "enable_identity_patch": true,
+ "identity_patch_prompt": "",
+ "ops_monitoring_enabled": false,
+ "ops_realtime_monitoring_enabled": true,
+ "ops_query_mode_default": "auto",
+ "ops_metrics_interval_seconds": 60,
+ "min_claude_code_version": "",
+ "max_claude_code_version": "",
+ "allow_ungrouped_key_scheduling": false,
+ "backend_mode_enabled": false,
+ "enable_fingerprint_unification": true,
+ "enable_metadata_passthrough": false,
+ "enable_cch_signing": false,
+ "enable_anthropic_cache_ttl_1h_injection": false,
+ "web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "",
+ "payment_visible_method_wxpay_source": "",
+ "payment_visible_method_alipay_enabled": false,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": false,
+ "openai_fast_policy_settings": {
+ "rules": [
+ {
+ "service_tier": "priority",
+ "action": "filter",
+ "scope": "all",
+ "fallback_action": "pass"
+ }
+ ]
+ },
+ "payment_enabled": false,
+ "payment_min_amount": 0,
+ "payment_max_amount": 0,
+ "payment_daily_limit": 0,
+ "payment_order_timeout_minutes": 0,
+ "payment_max_pending_orders": 0,
+ "payment_enabled_types": null,
+ "payment_balance_disabled": false,
+ "payment_balance_recharge_multiplier": 0,
+ "payment_recharge_fee_rate": 0,
+ "payment_load_balance_strategy": "",
+ "payment_product_name_prefix": "",
+ "payment_product_name_suffix": "",
+ "payment_help_image_url": "",
+ "payment_help_text": "",
+ "payment_cancel_rate_limit_enabled": false,
+ "payment_cancel_rate_limit_max": 0,
+ "payment_cancel_rate_limit_window": 0,
+ "payment_cancel_rate_limit_unit": "",
+ "payment_cancel_rate_limit_window_mode": "",
+ "balance_low_notify_enabled": false,
+ "account_quota_notify_enabled": false,
+ "balance_low_notify_threshold": 0,
+ "balance_low_notify_recharge_url": "",
+ "account_quota_notify_emails": [],
+ "channel_monitor_enabled": true,
+ "channel_monitor_default_interval_seconds": 60,
+ "available_channels_enabled": false,
+ "affiliate_enabled": false,
+ "wechat_connect_enabled": true,
+ "wechat_connect_app_id": "wx-open-config",
+ "wechat_connect_app_secret_configured": true,
+ "wechat_connect_mode": "open",
+ "wechat_connect_open_enabled": true,
+ "wechat_connect_open_app_id": "wx-open-config",
+ "wechat_connect_open_app_secret_configured": true,
+ "wechat_connect_mp_enabled": false,
+ "wechat_connect_mp_app_id": "wx-open-config",
+ "wechat_connect_mp_app_secret_configured": true,
+ "wechat_connect_mobile_enabled": false,
+ "wechat_connect_mobile_app_id": "wx-open-config",
+ "wechat_connect_mobile_app_secret_configured": true,
+ "wechat_connect_redirect_url": "",
+ "wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
+ "wechat_connect_scopes": "snsapi_login",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false
}
}`,
},
@@ -665,6 +1064,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
+ cfg *config.Config
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
@@ -726,7 +1126,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -785,6 +1185,7 @@ func newContractDeps(t *testing.T) *contractDeps {
return &contractDeps{
now: now,
router: r,
+ cfg: cfg,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
@@ -858,6 +1259,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
+func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -894,6 +1307,26 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
return errors.New("not implemented")
}
+func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ return nil
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index ed2578c8..dde92dfd 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -19,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
- authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
+func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -189,6 +214,14 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
+func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go
index 46482af3..ae53037e 100644
--- a/backend/internal/server/middleware/backend_mode_guard.go
+++ b/backend/internal/server/middleware/backend_mode_guard.go
@@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
}
}
+func backendModeAllowsAuthPath(path string) bool {
+ path = strings.ToLower(strings.TrimSpace(path))
+ for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
+ if strings.HasSuffix(path, suffix) {
+ return true
+ }
+ }
+
+ for _, suffix := range []string{
+ "/auth/oauth/linuxdo/callback",
+ "/auth/oauth/wechat/callback",
+ "/auth/oauth/wechat/payment/callback",
+ "/auth/oauth/oidc/callback",
+ "/auth/oauth/linuxdo/complete-registration",
+ "/auth/oauth/wechat/complete-registration",
+ "/auth/oauth/oidc/complete-registration",
+ "/auth/oauth/linuxdo/create-account",
+ "/auth/oauth/wechat/create-account",
+ "/auth/oauth/oidc/create-account",
+ "/auth/oauth/linuxdo/bind-login",
+ "/auth/oauth/wechat/bind-login",
+ "/auth/oauth/oidc/bind-login",
+ } {
+ if strings.HasSuffix(path, suffix) {
+ return true
+ }
+ }
+
+ return strings.Contains(path, "/auth/oauth/pending/")
+}
+
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
-// Allows: login, login/2fa, logout, refresh (admin needs these).
-// Blocks: register, forgot-password, reset-password, OAuth, etc.
+// Allows the minimal auth surface admins still need in backend mode, including
+// OAuth callbacks and pending continuations. Handler-level backend mode checks
+// still enforce admin-only login and forbid self-service registration.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
- path := c.Request.URL.Path
- // Allow login, 2FA, logout, refresh, public settings
- allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
- for _, suffix := range allowedSuffixes {
- if strings.HasSuffix(path, suffix) {
- c.Next()
- return
- }
+ if backendModeAllowsAuthPath(c.Request.URL.Path) {
+ c.Next()
+ return
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go
index 8878ebc9..bd77677b 100644
--- a/backend/internal/server/middleware/backend_mode_guard_test.go
+++ b/backend/internal/server/middleware/backend_mode_guard_test.go
@@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
+ {
+ name: "enabled_blocks_linuxdo_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_linuxdo_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_wechat_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_wechat_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_wechat_payment_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/payment/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_wechat_payment_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/payment/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_oidc_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_oidc_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_exchange",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/exchange",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_send_verify_code",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/send-verify-code",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/bind-login",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_provider_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/bind-login",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_provider_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_legacy_complete_registration",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/complete-registration",
+ wantStatus: http.StatusOK,
+ },
{
name: "enabled_blocks_register",
enabled: "true",
diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go
index 4aceb355..48cb9004 100644
--- a/backend/internal/server/middleware/jwt_auth.go
+++ b/backend/internal/server/middleware/jwt_auth.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "context"
"errors"
"strings"
@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
- return JWTAuthMiddleware(jwtAuth(authService, userService))
+ return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
+}
+
+type jwtUserReader interface {
+ GetByID(ctx context.Context, id int64) (*service.User, error)
+}
+
+type userActivityToucher interface {
+ TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
-func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
+func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
+ if activityToucher != nil {
+ activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
+ }
c.Next()
}
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index c483a51e..a643d3bc 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
+func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
+ return nil
+}
+
+type recordingActivityToucher struct {
+ userIDs []int64
+}
+
+func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
+ if user == nil {
+ return
+ }
+ r.userIDs = append(r.userIDs, user.ID)
+}
+
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
@@ -40,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
- authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
+func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
+ user := &service.User{
+ ID: 1,
+ Email: "test@example.com",
+ Role: "user",
+ Status: service.StatusActive,
+ Concurrency: 5,
+ TokenVersion: 1,
+ }
+
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
+ cfg.JWT.AccessTokenExpireMinutes = 60
+
+ userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ toucher := &recordingActivityToucher{}
+
+ r := gin.New()
+ r.Use(jwtAuth(authSvc, userSvc, toucher))
+ r.GET("/protected", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ token, err := authSvc.GenerateToken(user)
+ require.NoError(t, err)
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/protected", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ r.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusOK, w.Code)
+ require.Equal(t, []int64{1}, toucher.userIDs)
+}
+
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 7021ab2e..398c0351 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -96,7 +96,8 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
- strings.HasPrefix(path, "/responses")
+ strings.HasPrefix(path, "/responses") ||
+ strings.HasPrefix(path, "/images")
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 9af0fd8e..1c786f50 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -88,6 +88,12 @@ func RegisterAdminRoutes(
// 渠道管理
registerChannelRoutes(admin, h)
+
+ // 渠道监控
+ registerChannelMonitorRoutes(admin, h)
+
+ // 邀请返利(专属用户管理)
+ registerAffiliateRoutes(admin, h)
}
}
@@ -212,6 +218,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
+ users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
@@ -220,6 +227,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
+ users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
@@ -243,6 +251,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
+ groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
+ groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
@@ -563,3 +573,42 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
channels.DELETE("/:id", h.Admin.Channel.Delete)
}
}
+
+func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ monitors := admin.Group("/channel-monitors")
+ {
+ monitors.GET("", h.Admin.ChannelMonitor.List)
+ monitors.POST("", h.Admin.ChannelMonitor.Create)
+ monitors.GET("/:id", h.Admin.ChannelMonitor.Get)
+ monitors.PUT("/:id", h.Admin.ChannelMonitor.Update)
+ monitors.DELETE("/:id", h.Admin.ChannelMonitor.Delete)
+ monitors.POST("/:id/run", h.Admin.ChannelMonitor.Run)
+ monitors.GET("/:id/history", h.Admin.ChannelMonitor.History)
+ }
+
+ templates := admin.Group("/channel-monitor-templates")
+ {
+ templates.GET("", h.Admin.ChannelMonitorTemplate.List)
+ templates.POST("", h.Admin.ChannelMonitorTemplate.Create)
+ templates.GET("/:id", h.Admin.ChannelMonitorTemplate.Get)
+ templates.PUT("/:id", h.Admin.ChannelMonitorTemplate.Update)
+ templates.DELETE("/:id", h.Admin.ChannelMonitorTemplate.Delete)
+ templates.GET("/:id/monitors", h.Admin.ChannelMonitorTemplate.AssociatedMonitors)
+ templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
+ }
+}
+
+// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
+func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ affiliates := admin.Group("/affiliates")
+ {
+ users := affiliates.Group("/users")
+ {
+ users.GET("", h.Admin.Affiliate.ListUsers)
+ users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
+ users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
+ users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
+ users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
+ }
+ }
+}
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index c143b030..642a2103 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -63,14 +63,90 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
+ auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.LinuxDoOAuthStart(c)
+ })
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
+ auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
+ auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.WeChatOAuthStart(c)
+ })
+ auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
+ auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
+ auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
+ auth.POST("/oauth/pending/exchange",
+ rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.ExchangePendingOAuthCompletion,
+ )
+ auth.POST("/oauth/pending/send-verify-code",
+ rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.SendPendingOAuthVerifyCode,
+ )
+ auth.POST("/oauth/pending/create-account",
+ rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreatePendingOAuthAccount,
+ )
+ auth.POST("/oauth/pending/bind-login",
+ rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindPendingOAuthLogin,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/linuxdo/bind-login",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindLinuxDoOAuthLogin,
+ )
+ auth.POST("/oauth/linuxdo/create-account",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateLinuxDoOAuthAccount,
+ )
+ auth.POST("/oauth/wechat/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteWeChatOAuthRegistration,
+ )
+ auth.POST("/oauth/wechat/bind-login",
+ rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindWeChatOAuthLogin,
+ )
+ auth.POST("/oauth/wechat/create-account",
+ rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateWeChatOAuthAccount,
+ )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
+ auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.OIDCOAuthStart(c)
+ })
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
@@ -78,6 +154,18 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
+ auth.POST("/oauth/oidc/bind-login",
+ rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindOIDCOAuthLogin,
+ )
+ auth.POST("/oauth/oidc/create-account",
+ rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateOIDCOAuthAccount,
+ )
}
// 公开设置(无需认证)
@@ -94,5 +182,6 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
+ authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
}
}
diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go
index 4f411cec..07a66efb 100644
--- a/backend/internal/server/routes/auth_rate_limit_test.go
+++ b/backend/internal/server/routes/auth_rate_limit_test.go
@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
+ "/api/v1/auth/oauth/pending/send-verify-code",
}
for _, path := range paths {
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index cbf98293..9541cda1 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -88,6 +88,30 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
+ gateway.POST("/images/generations", func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
+ gateway.POST("/images/edits", func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
@@ -116,6 +140,13 @@ func RegisterGatewayRoutes(
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
+ codexDirect := r.Group("/backend-api/codex")
+ codexDirect.Use(bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic)
+ {
+ codexDirect.POST("/responses", responsesHandler)
+ codexDirect.POST("/responses/*subpath", responsesHandler)
+ codexDirect.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
+ }
// OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
@@ -124,6 +155,30 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
+ r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
+ r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go
index 4d65a626..19ef5686 100644
--- a/backend/internal/server/routes/gateway_test.go
+++ b/backend/internal/server/routes/gateway_test.go
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -24,6 +25,11 @@ func newGatewayRoutesTestRouter() *gin.Engine {
OpenAIGateway: &handler.OpenAIGatewayHandler{},
},
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
+ groupID := int64(1)
+ c.Set(string(servermiddleware.ContextKeyAPIKey), &service.APIKey{
+ GroupID: &groupID,
+ Group: &service.Group{Platform: service.PlatformOpenAI},
+ })
c.Next()
}),
nil,
@@ -39,7 +45,12 @@ func newGatewayRoutesTestRouter() *gin.Engine {
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
router := newGatewayRoutesTestRouter()
- for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
+ for _, path := range []string{
+ "/v1/responses/compact",
+ "/responses/compact",
+ "/backend-api/codex/responses",
+ "/backend-api/codex/responses/compact",
+ } {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -48,3 +59,21 @@ func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
}
}
+
+func TestGatewayRoutesOpenAIImagesPathsAreRegistered(t *testing.T) {
+ router := newGatewayRoutesTestRouter()
+
+ for _, path := range []string{
+ "/v1/images/generations",
+ "/v1/images/edits",
+ "/images/generations",
+ "/images/edits",
+ } {
+ req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-image-2","prompt":"draw a cat"}`))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ router.ServeHTTP(w, req)
+ require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI images handler", path)
+ }
+}
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
index 23bd58ad..e4828ead 100644
--- a/backend/internal/server/routes/payment.go
+++ b/backend/internal/server/routes/payment.go
@@ -44,11 +44,13 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
- // Payment result page needs to verify order status without login
- // (user session may have expired during provider redirect).
+ // Signed resume-token recovery is the preferred public lookup path.
+ // The legacy anonymous out_trade_no verify endpoint remains available as a
+ // persisted-state compatibility path for staggered upgrades.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
+ public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken)
}
// --- Webhook endpoints (no auth) ---
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index d004f8b4..9976954c 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -25,6 +25,12 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ user.GET("/aff", h.User.GetAffiliate)
+ user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
+ user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
+ user.POST("/account-bindings/email", h.User.BindEmailIdentity)
+ user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
+ user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")
@@ -64,6 +70,12 @@ func RegisterUserRoutes(
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
+ // 用户可用渠道(非管理员接口)
+ channels := authenticated.Group("/channels")
+ {
+ channels.GET("/available", h.AvailableChannel.List)
+ }
+
// 使用记录
usage := authenticated.Group("/usage")
{
@@ -99,5 +111,12 @@ func RegisterUserRoutes(
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
+
+ // 渠道监控(用户只读)
+ monitors := authenticated.Group("/channel-monitors")
+ {
+ monitors.GET("", h.ChannelMonitor.List)
+ monitors.GET("/:id/status", h.ChannelMonitor.GetStatus)
+ }
}
}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 52db3073..cd06ffa3 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -121,6 +121,9 @@ func (a *Account) IsSchedulable() bool {
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
+ if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() {
+ return false
+ }
return true
}
@@ -390,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return 0
}
+const (
+ // OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
+ OpenAICompactModeAuto = "auto"
+ // OpenAICompactModeForceOn always treats the account as compact-supported.
+ OpenAICompactModeForceOn = "force_on"
+ // OpenAICompactModeForceOff always treats the account as compact-unsupported.
+ OpenAICompactModeForceOff = "force_off"
+)
+
+func normalizeOpenAICompactMode(mode string) string {
+ switch strings.ToLower(strings.TrimSpace(mode)) {
+ case OpenAICompactModeForceOn:
+ return OpenAICompactModeForceOn
+ case OpenAICompactModeForceOff:
+ return OpenAICompactModeForceOff
+ default:
+ return OpenAICompactModeAuto
+ }
+}
+
+func stringMappingFromRaw(raw any) map[string]string {
+ switch mapping := raw.(type) {
+ case map[string]any:
+ if len(mapping) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(mapping))
+ for key, value := range mapping {
+ if str, ok := value.(string); ok {
+ result[key] = str
+ }
+ }
+ if len(result) == 0 {
+ return nil
+ }
+ return result
+ case map[string]string:
+ if len(mapping) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(mapping))
+ for key, value := range mapping {
+ result[key] = value
+ }
+ return result
+ default:
+ return nil
+ }
+}
+
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
@@ -595,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return requestedModel, false
}
+// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
+// Missing or invalid values fall back to "auto".
+func (a *Account) GetOpenAICompactMode() string {
+ if a == nil || !a.IsOpenAI() || a.Extra == nil {
+ return OpenAICompactModeAuto
+ }
+ mode, _ := a.Extra["openai_compact_mode"].(string)
+ return normalizeOpenAICompactMode(mode)
+}
+
+// OpenAICompactSupportKnown reports whether compact capability is known for this
+// account and, when known, whether it is supported.
+func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
+ if a == nil || !a.IsOpenAI() {
+ return false, false
+ }
+
+ switch a.GetOpenAICompactMode() {
+ case OpenAICompactModeForceOn:
+ return true, true
+ case OpenAICompactModeForceOff:
+ return false, true
+ }
+
+ if a.Extra == nil {
+ return false, false
+ }
+ supported, ok := a.Extra["openai_compact_supported"].(bool)
+ if !ok {
+ return false, false
+ }
+ return supported, true
+}
+
+// AllowsOpenAICompact reports whether the account may be considered for compact
+// requests. Unknown capability remains allowed to avoid breaking older accounts
+// before an explicit probe has been run.
+func (a *Account) AllowsOpenAICompact() bool {
+ if a == nil || !a.IsOpenAI() {
+ return false
+ }
+ supported, known := a.OpenAICompactSupportKnown()
+ if !known {
+ return true
+ }
+ return supported
+}
+
+// GetCompactModelMapping returns compact-only model remapping configuration.
+// This mapping is intended for /responses/compact only and does not affect
+// normal /responses traffic.
+func (a *Account) GetCompactModelMapping() map[string]string {
+ if a == nil || a.Credentials == nil {
+ return nil
+ }
+ return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
+}
+
+// ResolveCompactMappedModel resolves compact-only model remapping and reports
+// whether a compact-specific mapping rule matched.
+func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
+ mapping := a.GetCompactModelMapping()
+ if len(mapping) == 0 {
+ return requestedModel, false
+ }
+ if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
+ return mappedModel, true
+ }
+ return requestedModel, false
+}
+
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
return ""
@@ -908,6 +1032,32 @@ func (a *Account) GetChatGPTAccountID() string {
return a.GetCredential("chatgpt_account_id")
}
+func (a *Account) GetOpenAIDeviceID() string {
+ if !a.IsOpenAIOAuth() {
+ return ""
+ }
+ return strings.TrimSpace(a.GetExtraString("openai_device_id"))
+}
+
+func (a *Account) GetOpenAISessionID() string {
+ if !a.IsOpenAIOAuth() {
+ return ""
+ }
+ return strings.TrimSpace(a.GetExtraString("openai_session_id"))
+}
+
+func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
+ if !a.IsOpenAI() {
+ return false
+ }
+ switch capability {
+ case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative:
+ return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
+ default:
+ return true
+ }
+}
+
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
diff --git a/backend/internal/service/account_openai_compact_test.go b/backend/internal/service/account_openai_compact_test.go
new file mode 100644
index 00000000..442b00da
--- /dev/null
+++ b/backend/internal/service/account_openai_compact_test.go
@@ -0,0 +1,369 @@
+package service
+
+import "testing"
+
+func TestAccountGetOpenAICompactMode(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want string
+ }{
+ {
+ name: "nil account defaults to auto",
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "non openai account defaults to auto",
+ account: &Account{
+ Platform: PlatformAnthropic,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
+ },
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "missing extra defaults to auto",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ },
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "invalid mode falls back to auto",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": " invalid "},
+ },
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "force on is normalized",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
+ },
+ want: OpenAICompactModeForceOn,
+ },
+ {
+ name: "force off is normalized",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": "force_off"},
+ },
+ want: OpenAICompactModeForceOff,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.account.GetOpenAICompactMode(); got != tt.want {
+ t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAccountOpenAICompactSupportKnown(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ wantSupported bool
+ wantKnown bool
+ }{
+ {
+ name: "nil account is unknown",
+ wantSupported: false,
+ wantKnown: false,
+ },
+ {
+ name: "non openai account is unknown",
+ account: &Account{
+ Platform: PlatformAnthropic,
+ Extra: map[string]any{"openai_compact_supported": true},
+ },
+ wantSupported: false,
+ wantKnown: false,
+ },
+ {
+ name: "force on overrides probe state",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{
+ "openai_compact_mode": OpenAICompactModeForceOn,
+ "openai_compact_supported": false,
+ },
+ },
+ wantSupported: true,
+ wantKnown: true,
+ },
+ {
+ name: "force off overrides probe state",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{
+ "openai_compact_mode": OpenAICompactModeForceOff,
+ "openai_compact_supported": true,
+ },
+ },
+ wantSupported: false,
+ wantKnown: true,
+ },
+ {
+ name: "auto true is known supported",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": true},
+ },
+ wantSupported: true,
+ wantKnown: true,
+ },
+ {
+ name: "auto false is known unsupported",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": false},
+ },
+ wantSupported: false,
+ wantKnown: true,
+ },
+ {
+ name: "auto without probe state remains unknown",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{},
+ },
+ wantSupported: false,
+ wantKnown: false,
+ },
+ {
+ name: "invalid probe field remains unknown",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": "true"},
+ },
+ wantSupported: false,
+ wantKnown: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
+ if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
+ t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
+ }
+ })
+ }
+}
+
+func TestAccountAllowsOpenAICompact(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "nil account does not allow compact",
+ want: false,
+ },
+ {
+ name: "non openai account does not allow compact",
+ account: &Account{
+ Platform: PlatformAnthropic,
+ },
+ want: false,
+ },
+ {
+ name: "unknown openai account remains allowed",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{},
+ },
+ want: true,
+ },
+ {
+ name: "supported openai account is allowed",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": true},
+ },
+ want: true,
+ },
+ {
+ name: "unsupported openai account is rejected",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": false},
+ },
+ want: false,
+ },
+ {
+ name: "force on is allowed",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
+ },
+ want: true,
+ },
+ {
+ name: "force off is rejected",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.account.AllowsOpenAICompact(); got != tt.want {
+ t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAccountGetCompactModelMapping(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want map[string]string
+ }{
+ {
+ name: "nil account returns nil",
+ want: nil,
+ },
+ {
+ name: "missing credentials returns nil",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ },
+ want: nil,
+ },
+ {
+ name: "map any is converted",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ "invalid": 1,
+ },
+ },
+ },
+ want: map[string]string{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ },
+ },
+ {
+ name: "map string string is copied",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]string{
+ "gpt-*": "compact-*",
+ },
+ },
+ },
+ want: map[string]string{
+ "gpt-*": "compact-*",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.account.GetCompactModelMapping()
+ if !equalStringMap(got, tt.want) {
+ t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAccountResolveCompactMappedModel(t *testing.T) {
+ tests := []struct {
+ name string
+ credentials map[string]any
+ requestedModel string
+ expectedModel string
+ expectedMatch bool
+ }{
+ {
+ name: "no compact mapping reports unmatched",
+ credentials: nil,
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ expectedMatch: false,
+ },
+ {
+ name: "exact compact mapping matches",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ },
+ },
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4-openai-compact",
+ expectedMatch: true,
+ },
+ {
+ name: "exact passthrough counts as match",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4",
+ },
+ },
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ expectedMatch: true,
+ },
+ {
+ name: "longest wildcard wins",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-*": "fallback-compact",
+ "gpt-5.4*": "gpt-5.4-openai-compact",
+ "gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
+ },
+ },
+ requestedModel: "gpt-5.4-mini",
+ expectedModel: "gpt-5.4-mini-openai-compact",
+ expectedMatch: true,
+ },
+ {
+ name: "missing compact mapping reports unmatched",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.3": "gpt-5.3-openai-compact",
+ },
+ },
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ expectedMatch: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Credentials: tt.credentials,
+ }
+ gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
+ if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
+ t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
+ }
+ })
+ }
+}
+
+func equalStringMap(left, right map[string]string) bool {
+ if len(left) != len(right) {
+ return false
+ }
+ for key, want := range right {
+ if got, ok := left[key]; !ok || got != want {
+ return false
+ }
+ }
+ return true
+}
diff --git a/backend/internal/service/account_quota_schedulable_test.go b/backend/internal/service/account_quota_schedulable_test.go
new file mode 100644
index 00000000..2895b34c
--- /dev/null
+++ b/backend/internal/service/account_quota_schedulable_test.go
@@ -0,0 +1,123 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) {
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "apikey daily quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey weekly quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_weekly_limit": 50.0,
+ "quota_weekly_used": 50.0,
+ "quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey total quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_limit": 100.0,
+ "quota_used": 100.0,
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey quota not exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 5.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "apikey expired daily period restores schedulable",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-25 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "oauth ignores quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "bedrock quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeBedrock,
+ Extra: map[string]any{
+ "quota_limit": 200.0,
+ "quota_used": 200.0,
+ },
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, tt.account.IsSchedulable())
+ })
+ }
+}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index a5559b7d..391e7475 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -52,12 +52,19 @@ type TestEvent struct {
const (
defaultGeminiTextTestPrompt = "hi"
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
+ defaultOpenAIImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
+// isOpenAIImageModel checks if the model is an OpenAI image generation model (e.g. gpt-image-2).
+func isOpenAIImageModel(model string) bool {
+ return strings.HasPrefix(strings.ToLower(model), "gpt-image-")
+}
+
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
geminiTokenProvider *GeminiTokenProvider
+ claudeTokenProvider *ClaudeTokenProvider
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
@@ -68,6 +75,7 @@ type AccountTestService struct {
func NewAccountTestService(
accountRepo AccountRepository,
geminiTokenProvider *GeminiTokenProvider,
+ claudeTokenProvider *ClaudeTokenProvider,
antigravityGatewayService *AntigravityGatewayService,
httpUpstream HTTPUpstream,
cfg *config.Config,
@@ -76,6 +84,7 @@ func NewAccountTestService(
return &AccountTestService{
accountRepo: accountRepo,
geminiTokenProvider: geminiTokenProvider,
+ claudeTokenProvider: claudeTokenProvider,
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
@@ -159,7 +168,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
-func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
+// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
+func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
// Get account
@@ -170,7 +180,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
- return s.testOpenAIAccountConnection(c, account, modelID)
+ return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
}
if account.IsGemini() {
@@ -203,6 +213,9 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
if account.IsBedrock() {
return s.testBedrockAccountConnection(c, ctx, account, testModelID)
}
+ if account.Type == AccountTypeServiceAccount {
+ return s.testClaudeVertexServiceAccountConnection(c, ctx, account, testModelID)
+ }
// Determine authentication method and API URL
var authToken string
@@ -306,6 +319,74 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
return s.processClaudeStream(c, resp.Body)
}
+func (s *AccountTestService) testClaudeVertexServiceAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
+ if mappedModel, matched := account.ResolveMappedModel(testModelID); matched {
+ testModelID = mappedModel
+ } else {
+ testModelID = normalizeVertexAnthropicModelID(claude.NormalizeModelID(testModelID))
+ }
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ payload, err := createTestPayload(testModelID)
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create test payload")
+ }
+ payloadBytes, _ := json.Marshal(payload)
+ vertexBody, err := buildVertexAnthropicRequestBody(payloadBytes)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Vertex request body: %s", err.Error()))
+ }
+
+ if s.claudeTokenProvider == nil {
+ return s.sendErrorAndEnd(c, "Claude token provider not configured")
+ }
+ accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to get service account access token: %s", err.Error()))
+ }
+
+ fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(testModelID), testModelID, true)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build Vertex URL: %s", err.Error()))
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ errMsg := fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))
+ if resp.StatusCode == http.StatusForbidden {
+ _ = s.accountRepo.SetError(ctx, account.ID, errMsg)
+ }
+ return s.sendErrorAndEnd(c, errMsg)
+ }
+
+ return s.processClaudeStream(c, resp.Body)
+}
+
// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke
func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error {
region := bedrockRuntimeRegion(account)
@@ -410,8 +491,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
-func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
+func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
+ _ = prompt
+ mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
@@ -419,14 +502,24 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID = openai.DefaultTestModel
}
- // For API Key accounts with model mapping, map the model
- if account.Type == "apikey" {
- mapping := account.GetModelMapping()
- if len(mapping) > 0 {
- if mappedModel, exists := mapping[testModelID]; exists {
- testModelID = mappedModel
- }
+ // Align test routing with gateway behavior: OpenAI accounts apply normal
+ // account model mapping, and compact mode applies compact-only mapping on top.
+ testModelID = account.GetMappedModel(testModelID)
+ if mode == AccountTestModeCompact {
+ testModelID = resolveOpenAICompactForwardModel(account, testModelID)
+ return s.testOpenAICompactConnection(c, account, testModelID)
+ }
+
+ // Route to image generation test if an image model is selected
+ if isOpenAIImageModel(testModelID) {
+ imagePrompt := strings.TrimSpace(prompt)
+ if imagePrompt == "" {
+ imagePrompt = defaultOpenAIImageTestPrompt
}
+ if account.Type == "apikey" {
+ return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt)
+ }
+ return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt)
}
// Determine authentication method and API URL
@@ -519,6 +612,9 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode == http.StatusTooManyRequests {
+ s.reconcileOpenAI429State(ctx, account, resp.Header, body)
+ }
// 401 Unauthorized: 标记账号为永久错误
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
@@ -531,6 +627,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
+// testOpenAICompactConnection probes /responses/compact and persists the
+// resulting capability state on the account.
+func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
+ ctx := c.Request.Context()
+
+ authToken := ""
+ apiURL := ""
+ isOAuth := false
+ chatgptAccountID := ""
+
+ switch {
+ case account.IsOAuth():
+ isOAuth = true
+ authToken = account.GetOpenAIAccessToken()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No access token available")
+ }
+ apiURL = chatgptCodexAPIURL + "/compact"
+ chatgptAccountID = account.GetChatGPTAccountID()
+ case account.Type == AccountTypeAPIKey:
+ authToken = account.GetOpenAIApiKey()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No API key available")
+ }
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.openai.com"
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
+ default:
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
+ }
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+
+ req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+authToken)
+ req.Header.Set("OpenAI-Beta", "responses=experimental")
+ req.Header.Set("Originator", "codex_cli_rs")
+ req.Header.Set("User-Agent", codexCLIUserAgent)
+ req.Header.Set("Version", codexCLIVersion)
+ probeSessionID := compactProbeSessionID(account.ID)
+ req.Header.Set("Session_ID", probeSessionID)
+ req.Header.Set("Conversation_ID", probeSessionID)
+
+ if isOAuth {
+ req.Host = "chatgpt.com"
+ if chatgptAccountID != "" {
+ req.Header.Set("chatgpt-account-id", chatgptAccountID)
+ }
+ }
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ if s.accountRepo != nil {
+ updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
+ _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
+ mergeAccountExtra(account, updates)
+ }
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+
+ if s.accountRepo != nil {
+ updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
+ if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
+ updates = mergeExtraUpdates(updates, codexUpdates)
+ }
+ if len(updates) > 0 {
+ _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
+ mergeAccountExtra(account, updates)
+ }
+ // 探测如返回 429,主动同步限流状态,避免后续短时间内继续选中。
+ if resp.StatusCode == http.StatusTooManyRequests {
+ s.reconcileOpenAI429State(ctx, account, resp.Header, body)
+ }
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
+ errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
+ _ = s.accountRepo.SetError(ctx, account.ID, errMsg)
+ }
+ return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
+ }
+
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
+func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
+ if s == nil || s.accountRepo == nil || account == nil {
+ return
+ }
+
+ var resetAt *time.Time
+ if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
+ resetAt = calculated
+ } else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
+ t := time.Unix(*unixTs, 0)
+ resetAt = &t
+ }
+ if resetAt == nil {
+ return
+ }
+
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
+ return
+ }
+
+ now := time.Now()
+ account.RateLimitedAt = &now
+ account.RateLimitResetAt = resetAt
+
+ if account.Status == StatusError {
+ if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
+ return
+ }
+ account.Status = StatusActive
+ account.ErrorMessage = ""
+ }
+}
+
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
@@ -541,8 +785,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
testModelID = geminicli.DefaultTestModel
}
- // For API Key accounts with model mapping, map the model
- if account.Type == AccountTypeAPIKey {
+ // For static upstream credentials with model mapping, map the model
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -570,6 +814,8 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
+ case AccountTypeServiceAccount:
+ req, err = s.buildGeminiServiceAccountRequest(ctx, account, testModelID, payload)
default:
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
@@ -723,6 +969,27 @@ func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, accoun
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
}
+func (s *AccountTestService) buildGeminiServiceAccountRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
+ if s.geminiTokenProvider == nil {
+ return nil, fmt.Errorf("gemini token provider not configured")
+ }
+ accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get service account access token: %w", err)
+ }
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, "streamGenerateContent", true)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ return req, nil
+}
+
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
var inner map[string]any
@@ -975,13 +1242,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
+ seenCompleted := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
+ if seenCompleted {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+ return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
@@ -993,8 +1264,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
+ if seenCompleted {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+ return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
var data map[string]any
@@ -1010,9 +1284,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
- case "response.completed":
+ case "response.completed", "response.done":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
+ case "response.failed":
+ errorMsg := "OpenAI response failed"
+ if responseData, ok := data["response"].(map[string]any); ok {
+ if errData, ok := responseData["error"].(map[string]any); ok {
+ if msg, ok := errData["message"].(string); ok && msg != "" {
+ errorMsg = msg
+ }
+ }
+ }
+ return s.sendErrorAndEnd(c, errorMsg)
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
@@ -1025,7 +1309,198 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
}
}
-// sendEvent sends a SSE event to the client
+// testOpenAIImageAPIKey tests OpenAI image generation using an API Key account.
+func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+ authToken := account.GetOpenAIApiKey()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No API key available")
+ }
+
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.openai.com"
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL := buildOpenAIImagesURL(normalizedBaseURL, openAIImagesGenerationsEndpoint)
+
+ // Set SSE headers
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+
+ payload := map[string]any{
+ "model": modelID,
+ "prompt": prompt,
+ "n": 1,
+ "response_format": "b64_json",
+ }
+ payloadBytes, _ := json.Marshal(payload)
+
+ req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+authToken)
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read response: %s", err.Error()))
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
+ }
+
+ // Parse {"data": [{"b64_json": "...", "revised_prompt": "..."}]}
+ var result struct {
+ Data []struct {
+ B64JSON string `json:"b64_json"`
+ RevisedPrompt string `json:"revised_prompt"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
+ }
+
+ if len(result.Data) == 0 {
+ return s.sendErrorAndEnd(c, "No images returned from API")
+ }
+
+ for _, item := range result.Data {
+ if item.RevisedPrompt != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
+ }
+ if item.B64JSON != "" {
+ s.sendEvent(c, TestEvent{
+ Type: "image",
+ ImageURL: "data:image/png;base64," + item.B64JSON,
+ MimeType: "image/png",
+ })
+ }
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
+// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
+func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+ authToken := account.GetOpenAIAccessToken()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No access token available")
+ }
+
+ // Set SSE headers
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})
+
+ parsed := &OpenAIImagesRequest{
+ Endpoint: openAIImagesGenerationsEndpoint,
+ Model: strings.TrimSpace(modelID),
+ Prompt: prompt,
+ }
+ applyOpenAIImagesDefaults(parsed)
+
+ responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error()))
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Host = "chatgpt.com"
+ req.Header.Set("Authorization", "Bearer "+authToken)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("OpenAI-Beta", "responses=experimental")
+ req.Header.Set("originator", "opencode")
+ if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
+ req.Header.Set("User-Agent", customUA)
+ } else {
+ req.Header.Set("User-Agent", codexCLIUserAgent)
+ }
+ if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
+ req.Header.Set("chatgpt-account-id", chatgptAccountID)
+ }
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+ resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error()))
+ }
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if resp.StatusCode >= 400 {
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ message := strings.TrimSpace(extractUpstreamErrorMessage(body))
+ if message == "" {
+ message = fmt.Sprintf("Responses API returned %d", resp.StatusCode)
+ }
+ return s.sendErrorAndEnd(c, message)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error()))
+ }
+
+ results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error()))
+ }
+ if len(results) == 0 {
+ return s.sendErrorAndEnd(c, "No images returned from responses API")
+ }
+
+ for _, item := range results {
+ if item.RevisedPrompt != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
+ }
+ mimeType := openAIImageOutputMIMEType(item.OutputFormat)
+ s.sendEvent(c, TestEvent{
+ Type: "image",
+ ImageURL: "data:" + mimeType + ";base64," + item.Result,
+ MimeType: mimeType,
+ })
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
@@ -1051,7 +1526,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
- testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
+ testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
finishedAt := time.Now()
body := w.Body.String()
diff --git a/backend/internal/service/account_test_service_openai_compact_test.go b/backend/internal/service/account_test_service_openai_compact_test.go
new file mode 100644
index 00000000..9eb98fdc
--- /dev/null
+++ b/backend/internal/service/account_test_service_openai_compact_test.go
@@ -0,0 +1,199 @@
+package service
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.NoError(t, err)
+
+ require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
+ require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
+ require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
+ require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
+ require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
+ require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
+ require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
+ require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
+
+ updates := <-updateCalls
+ require.Equal(t, true, updates["openai_compact_supported"])
+ require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
+ require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
+}
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 2,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusNotFound,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`404 page not found`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.Error(t, err)
+
+ updates := <-updateCalls
+ require.Equal(t, false, updates["openai_compact_supported"])
+ require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
+ require.Contains(t, rec.Body.String(), `"type":"error"`)
+}
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 3,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": "https://example.com/v1",
+ "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.NoError(t, err)
+
+ require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
+ require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
+ updates := <-updateCalls
+ require.Equal(t, true, updates["openai_compact_supported"])
+}
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 4,
+ Name: "openai-apikey-default",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.NoError(t, err)
+ require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
+ <-updateCalls
+}
diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go
new file mode 100644
index 00000000..257159c4
--- /dev/null
+++ b/backend/internal/service/account_test_service_openai_image_test.go
@@ -0,0 +1,90 @@
+package service
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 53,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat")
+ require.NoError(t, err)
+ require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool")
+ require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
+ require.Contains(t, rec.Body.String(), "\"success\":true")
+}
+
+func TestAccountTestService_OpenAIImageAPIKeyUsesConfiguredV1BaseURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ },
+ Body: io.NopCloser(strings.NewReader(`{"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+ },
+ }
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ cfg: &config.Config{},
+ }
+ account := &Account{
+ ID: 54,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "test-api-key",
+ "base_url": "https://image-upstream.example/v1",
+ },
+ }
+
+ err := svc.testOpenAIImageAPIKey(c, context.Background(), account, "gpt-image-2", "draw a cat")
+ require.NoError(t, err)
+ require.NotNil(t, upstream.lastReq)
+ require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
+ require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+ require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
+ require.Contains(t, rec.Body.String(), "\"success\":true")
+}
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index 82606979..56204be3 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -61,9 +61,12 @@ func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
- updatedExtra map[string]any
- rateLimitedID int64
- rateLimitedAt *time.Time
+ updatedExtra map[string]any
+ rateLimitedID int64
+ rateLimitedAt *time.Time
+ clearedErrorID int64
+ setErrorID int64
+ setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
@@ -77,6 +80,17 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
+func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
+ r.clearedErrorID = id
+ return nil
+}
+
+func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
+ r.setErrorID = id
+ r.setErrorMsg = errorMsg
+ return nil
+}
+
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
@@ -103,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
@@ -111,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
-func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
+func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, recorder := newTestContext()
+
+ resp := newJSONResponse(http.StatusOK, "")
+ resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
+
+`))
+
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 90,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Contains(t, recorder.Body.String(), "response.completed")
+ require.NotContains(t, recorder.Body.String(), `"success":true`)
+}
+
+func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newTestContext()
- resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
@@ -130,15 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
+ Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Equal(t, account.ID, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.Empty(t, account.ErrorMessage)
+ require.NotNil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 77,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Equal(t, account.ID, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.Empty(t, account.ErrorMessage)
+ require.NotNil(t, account.RateLimitResetAt)
+ require.Empty(t, repo.updatedExtra)
+}
+
+func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 78,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Zero(t, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.NotNil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 79,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "stale 403",
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
require.Zero(t, repo.rateLimitedID)
require.Nil(t, repo.rateLimitedAt)
+ require.Zero(t, repo.clearedErrorID)
+ require.Equal(t, StatusError, account.Status)
+ require.Equal(t, "stale 403", account.ErrorMessage)
+ require.Nil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 80,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.setErrorID)
+ require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
+ require.Zero(t, repo.rateLimitedID)
+ require.Zero(t, repo.clearedErrorID)
require.Nil(t, account.RateLimitResetAt)
}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index 8d5bcec8..68ba8f8c 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
- openAICodexProbeVersion = "0.104.0"
+ openAICodexProbeVersion = "0.125.0"
)
// UsageCache 封装账户使用量相关的缓存
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7c26a47c..d966c684 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -2,15 +2,20 @@ package service
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
+ "sort"
+ "strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -29,10 +34,12 @@ type AdminService interface {
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
+ GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
+ BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -46,10 +53,13 @@ type AdminService interface {
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
+ ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
+ BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
+ AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error)
// ReplaceUserGroup 替换用户的专属分组:授予新分组权限、迁移 Key、移除旧分组权限
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
@@ -110,6 +120,7 @@ type CreateUserInput struct {
Notes string
Balance float64
Concurrency int
+ RPMLimit int
AllowedGroups []int64
}
@@ -120,6 +131,7 @@ type UpdateUserInput struct {
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
+ RPMLimit *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
@@ -127,6 +139,44 @@ type UpdateUserInput struct {
GroupRates map[int64]*float64
}
+type AdminBindAuthIdentityInput struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+ Issuer *string
+ Metadata map[string]any
+ Channel *AdminBindAuthIdentityChannelInput
+}
+
+type AdminBindAuthIdentityChannelInput struct {
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+ Metadata map[string]any
+}
+
+type AdminBoundAuthIdentity struct {
+ UserID int64 `json:"user_id"`
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ Issuer *string `json:"issuer,omitempty"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
+}
+
+type AdminBoundAuthIdentityChannel struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
type CreateGroupInput struct {
Name string
Description string
@@ -157,6 +207,8 @@ type CreateGroupInput struct {
RequireOAuthOnly bool
RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
+ // RPMLimit 分组 RPM 上限(0 = 不限制)
+ RPMLimit int
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -192,6 +244,8 @@ type UpdateGroupInput struct {
RequireOAuthOnly *bool
RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
+ // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
+ RPMLimit *int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -239,6 +293,7 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
AccountIDs []int64
+ Filters *BulkUpdateAccountFilters
Name string
ProxyID *int64
Concurrency *int
@@ -255,6 +310,15 @@ type BulkUpdateAccountsInput struct {
SkipMixedChannelCheck bool
}
+type BulkUpdateAccountFilters struct {
+ Platform string
+ Type string
+ Status string
+ Group string
+ Search string
+ PrivacyMode string
+}
+
// BulkUpdateAccountResult captures the result for a single account update.
type BulkUpdateAccountResult struct {
AccountID int64 `json:"account_id"`
@@ -275,6 +339,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys int64 // 迁移的 Key 数量
}
+// UserRPMStatus describes a user's current per-minute RPM usage.
+type UserRPMStatus struct {
+ UserRPMUsed int `json:"user_rpm_used"`
+ UserRPMLimit int `json:"user_rpm_limit"`
+ PerGroup []UserGroupRPMStatus `json:"per_group"`
+}
+
+// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
+type UserGroupRPMStatus struct {
+ GroupID int64 `json:"group_id"`
+ GroupName string `json:"group_name"`
+ Used int `json:"used"`
+ Limit int `json:"limit"`
+ Source string `json:"source"` // "group" | "override"
+}
+
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -421,6 +501,8 @@ const (
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
+var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
+
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
@@ -430,6 +512,7 @@ type adminServiceImpl struct {
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
+ userRPMCache UserRPMCache
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
@@ -454,6 +537,7 @@ func NewAdminService(
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
+ userRPMCache UserRPMCache,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
@@ -472,6 +556,7 @@ func NewAdminService(
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
+ userRPMCache: userRPMCache,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
@@ -491,6 +576,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
+ if len(users) > 0 {
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
+ }
+ lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr)
+ } else {
+ for i := range users {
+ users[i].LastUsedAt = lastUsedByUserID[users[i].ID]
+ }
+ }
+ }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
@@ -535,6 +634,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if err != nil {
return nil, err
}
+ lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr)
+ } else {
+ user.LastUsedAt = lastUsedAt
+ }
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
@@ -555,6 +660,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
+ RPMLimit: input.RPMLimit,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
}
@@ -586,6 +692,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
+ // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
+ if input.GroupRates != nil {
+ for groupID, rate := range input.GroupRates {
+ if rate != nil && *rate <= 0 {
+ return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
+ }
+ }
+ }
+
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -599,6 +714,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
+ oldRPMLimit := user.RPMLimit
if input.Email != "" {
user.Email = input.Email
@@ -624,6 +740,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Concurrency = *input.Concurrency
}
+ if input.RPMLimit != nil {
+ user.RPMLimit = *input.RPMLimit
+ }
+
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
@@ -640,7 +760,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
if s.authCacheInvalidator != nil {
- if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
+ // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
+ // 不失效缓存会让修改在一个 L2 TTL 内失去效果。
+ if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
@@ -762,6 +884,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil
}
+func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
+ if s.userRPMCache == nil {
+ return nil, ErrRPMStatusUnavailable
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
+ if err != nil {
+ logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
+ }
+
+ keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
+ if err != nil {
+ return nil, err
+ }
+
+ groupIDSet := make(map[int64]struct{})
+ for _, key := range keys {
+ if key.GroupID != nil && *key.GroupID > 0 {
+ groupIDSet[*key.GroupID] = struct{}{}
+ }
+ }
+
+ groupIDs := make([]int64, 0, len(groupIDSet))
+ for groupID := range groupIDSet {
+ groupIDs = append(groupIDs, groupID)
+ }
+ sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
+
+ var perGroup []UserGroupRPMStatus
+ for _, groupID := range groupIDs {
+ used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
+ if getErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
+ }
+
+ entry := UserGroupRPMStatus{
+ GroupID: groupID,
+ Used: used,
+ }
+
+ if s.groupRepo != nil {
+ if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
+ entry.GroupName = group.Name
+ entry.Limit = group.RPMLimit
+ entry.Source = "group"
+ } else if groupErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
+ }
+ }
+
+ if s.userGroupRateRepo != nil {
+ override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
+ if overrideErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
+ } else if override != nil {
+ entry.Limit = *override
+ entry.Source = "override"
+ }
+ }
+
+ perGroup = append(perGroup, entry)
+ }
+
+ return &UserRPMStatus{
+ UserRPMUsed: userRPMUsed,
+ UserRPMLimit: user.RPMLimit,
+ PerGroup: perGroup,
+ }, nil
+}
+
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now
return map[string]any{
@@ -788,6 +985,334 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
return codes, result.Total, totalRecharged, nil
}
+func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
+ }
+ if s == nil || s.entClient == nil || s.userRepo == nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ return nil, err
+ }
+
+ providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
+ providerKey := strings.TrimSpace(input.ProviderKey)
+ providerSubject := strings.TrimSpace(input.ProviderSubject)
+ if providerType == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
+ }
+ if providerKey == "" || providerSubject == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
+ }
+ canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
+ compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
+
+ var issuer *string
+ if input.Issuer != nil {
+ trimmed := strings.TrimSpace(*input.Issuer)
+ if trimmed != "" {
+ issuer = &trimmed
+ }
+ }
+
+ channelInput := normalizeAdminBindChannelInput(input.Channel)
+ if input.Channel != nil && channelInput == nil {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
+ }
+
+ verifiedAt := time.Now().UTC()
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ identityRecords, err := tx.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(compatibleProviderKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
+
+ if identity == nil {
+ create := tx.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(canonicalProviderKey).
+ SetProviderSubject(providerSubject).
+ SetVerifiedAt(verifiedAt)
+ if issuer != nil {
+ create = create.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentity.UpdateOneID(identity.ID).
+ SetVerifiedAt(verifiedAt).
+ SetProviderKey(canonicalProviderKey)
+ if issuer != nil {
+ update = update.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if channelInput != nil {
+ channelRecords, err := tx.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
+ authidentitychannel.ChannelEQ(channelInput.Channel),
+ authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
+ if channel == nil {
+ create := tx.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(canonicalProviderKey).
+ SetChannel(channelInput.Channel).
+ SetChannelAppID(channelInput.ChannelAppID).
+ SetChannelSubject(channelInput.ChannelSubject)
+ if channelInput.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID).
+ SetProviderKey(canonicalProviderKey)
+ if channelInput.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
+ }
+ return buildAdminBoundAuthIdentity(identity, channel), nil
+}
+
+func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return []string{providerKey}
+ }
+ if providerType != "wechat" {
+ return []string{providerKey}
+ }
+
+ keys := []string{providerKey}
+ if !strings.EqualFold(providerKey, "wechat-main") {
+ keys = append(keys, "wechat-main")
+ }
+ if !strings.EqualFold(providerKey, "wechat") {
+ keys = append(keys, "wechat")
+ }
+ return keys
+}
+
+func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ existingKey = strings.TrimSpace(existingKey)
+ requestedKey = strings.TrimSpace(requestedKey)
+ if providerType != "wechat" {
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+ }
+ if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
+ return "wechat-main"
+ }
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+}
+
+func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerType != "wechat" {
+ return 0
+ }
+ switch {
+ case strings.EqualFold(providerKey, "wechat-main"):
+ return 0
+ case strings.EqualFold(providerKey, "wechat"):
+ return 2
+ default:
+ return 1
+ }
+}
+
+func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
+ var selected *dbent.AuthIdentity
+ for _, record := range records {
+ if record.UserID != userID {
+ continue
+ }
+ if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
+ for _, record := range records {
+ if record.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
+ var selected *dbent.AuthIdentityChannel
+ for _, record := range records {
+ if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
+ continue
+ }
+ if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
+ for _, record := range records {
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
+ if input == nil {
+ return nil
+ }
+ channel := &AdminBindAuthIdentityChannelInput{
+ Channel: strings.TrimSpace(input.Channel),
+ ChannelAppID: strings.TrimSpace(input.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(input.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
+ }
+ if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
+ return nil
+ }
+ return channel
+}
+
+func normalizeAdminAuthIdentityProviderType(input string) string {
+ switch strings.ToLower(strings.TrimSpace(input)) {
+ case "email":
+ return "email"
+ case "linuxdo":
+ return "linuxdo"
+ case "oidc":
+ return "oidc"
+ case "wechat":
+ return "wechat"
+ default:
+ return ""
+ }
+}
+
+func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
+ if identity == nil {
+ return nil
+ }
+ result := &AdminBoundAuthIdentity{
+ UserID: identity.UserID,
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ }
+ if channel != nil {
+ result.Channel = &AdminBoundAuthIdentityChannel{
+ Channel: strings.TrimSpace(channel.Channel),
+ ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
+ CreatedAt: channel.CreatedAt,
+ UpdatedAt: channel.UpdatedAt,
+ }
+ }
+ return result
+}
+
+func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
+ if input == nil {
+ return nil
+ }
+ if len(input) == 0 {
+ return map[string]any{}
+ }
+ data, err := json.Marshal(input)
+ if err != nil {
+ out := make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ return out
+ }
+ var out map[string]any
+ if err := json.Unmarshal(data, &out); err != nil {
+ out = make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ }
+ return out
+}
+
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
@@ -811,6 +1336,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
+ if input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
+
platform := input.Platform
if platform == "" {
platform = PlatformAnthropic
@@ -911,6 +1440,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
+ RPMLimit: input.RPMLimit,
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil {
@@ -1050,6 +1580,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
+ if *input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
@@ -1142,12 +1675,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
+ if input.RPMLimit != nil {
+ group.RPMLimit = *input.RPMLimit
+ }
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
+ }
+
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
@@ -1216,9 +1756,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
- if s.authCacheInvalidator != nil {
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
- }
return group, nil
}
@@ -1286,9 +1823,47 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
if s.userGroupRateRepo == nil {
return nil
}
+ for _, e := range entries {
+ if e.RateMultiplier <= 0 {
+ return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
+ }
+ }
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
+func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
+ if s.userGroupRateRepo == nil {
+ return nil
+ }
+ if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
+ return err
+ }
+ // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
+ }
+ return nil
+}
+
+func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
+ if s.userGroupRateRepo == nil {
+ return nil
+ }
+ for _, e := range entries {
+ if e.RPMOverride != nil && *e.RPMOverride < 0 {
+ return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
+ }
+ }
+ if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
+ return err
+ }
+ // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
+ }
+ return nil
+}
+
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
@@ -1398,6 +1973,30 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
return result, nil
}
+// AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows.
+func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) {
+ apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
+ if err != nil {
+ return nil, err
+ }
+ apiKey.Usage5h = 0
+ apiKey.Usage1d = 0
+ apiKey.Usage7d = 0
+ apiKey.Window5hStart = nil
+ apiKey.Window1dStart = nil
+ apiKey.Window7dStart = nil
+ if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
+ return nil, fmt.Errorf("reset api key rate limit usage: %w", err)
+ }
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ }
+ if s.billingCacheService != nil {
+ _ = s.billingCacheService.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
+ }
+ return apiKey, nil
+}
+
// ReplaceUserGroup 替换用户的专属分组
func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) {
if oldGroupID == newGroupID {
@@ -1723,6 +2322,14 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
+ if len(input.AccountIDs) == 0 && input.Filters != nil {
+ accountIDs, err := s.resolveBulkUpdateTargetIDs(ctx, input.Filters)
+ if err != nil {
+ return nil, err
+ }
+ input.AccountIDs = accountIDs
+ }
+
result := &BulkUpdateAccountsResult{
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
@@ -1838,6 +2445,55 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
+func (s *adminServiceImpl) resolveBulkUpdateTargetIDs(ctx context.Context, filters *BulkUpdateAccountFilters) ([]int64, error) {
+ if filters == nil {
+ return nil, nil
+ }
+
+ groupID := int64(0)
+ switch strings.TrimSpace(filters.Group) {
+ case "":
+ case "ungrouped":
+ groupID = AccountListGroupUngrouped
+ default:
+ parsedGroupID, err := strconv.ParseInt(strings.TrimSpace(filters.Group), 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid group filter: %w", err)
+ }
+ groupID = parsedGroupID
+ }
+
+ const pageSize = 500
+ page := 1
+ accountIDs := make([]int64, 0, pageSize)
+
+ for {
+ accounts, total, err := s.ListAccounts(
+ ctx,
+ page,
+ pageSize,
+ filters.Platform,
+ filters.Type,
+ filters.Status,
+ filters.Search,
+ groupID,
+ filters.PrivacyMode,
+ "",
+ "",
+ )
+ if err != nil {
+ return nil, err
+ }
+ for _, account := range accounts {
+ accountIDs = append(accountIDs, account.ID)
+ }
+ if int64(len(accountIDs)) >= total || len(accounts) == 0 {
+ return accountIDs, nil
+ }
+ page++
+ }
+}
+
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
if err := s.accountRepo.Delete(ctx, id); err != nil {
return err
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index 419ddbc3..fcde5cbf 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
@@ -70,6 +79,23 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected")
+}
+
+func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected")
+}
+
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go
new file mode 100644
index 00000000..719199f2
--- /dev/null
+++ b/backend/internal/service/admin_service_auth_identity_binding_test.go
@@ -0,0 +1,302 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("bind-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-123",
+ Metadata: map[string]any{"scene": "migration"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, user.ID, result.UserID)
+ require.Equal(t, "wechat", result.ProviderType)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.VerifiedAt)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+ require.NotNil(t, identity.VerifiedAt)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, identity.ID, channel.IdentityID)
+ require.Equal(t, "migration", channel.Metadata["scene"])
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ target, err := client.User.Create().
+ SetEmail("target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("subject-1").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ })
+ require.Error(t, err)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
+}
+
+func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("same-user@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "first"},
+ })
+ require.NoError(t, err)
+
+ second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "second"},
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.UserID, second.UserID)
+ require.Equal(t, "second", second.Metadata["source"])
+
+ identities, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("subject-2"),
+ ).
+ All(ctx)
+ require.NoError(t, err)
+ require.Len(t, identities, 1)
+ require.Equal(t, "second", identities[0].Metadata["source"])
+}
+
+func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("wechat-alias@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyChannel, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("open").
+ SetChannelAppID("wx-open").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-legacy-123",
+ Metadata: map[string]any{"scene": "admin-repair"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", identity.ProviderKey)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+
+ channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", channel.ProviderKey)
+ require.Equal(t, legacyIdentity.ID, channel.IdentityID)
+ require.Equal(t, "admin-repair", channel.Metadata["scene"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, channelCount)
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("invalid-provider@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "github",
+ ProviderKey: "github-main",
+ ProviderSubject: "subject-3",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
+}
diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go
index 4845d87c..df415295 100644
--- a/backend/internal/service/admin_service_bulk_update_test.go
+++ b/backend/internal/service/admin_service_bulk_update_test.go
@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
+ "reflect"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -25,6 +27,19 @@ type accountRepoStubForBulkUpdate struct {
getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
+ listData []Account
+ listResult *pagination.PaginationResult
+ listErr error
+ listCalled bool
+ lastListParams pagination.PaginationParams
+ lastListFilters struct {
+ platform string
+ accountType string
+ status string
+ search string
+ groupID int64
+ privacyMode string
+ }
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -73,6 +88,24 @@ func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID in
return nil, nil
}
+func (s *accountRepoStubForBulkUpdate) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, *pagination.PaginationResult, error) {
+ s.listCalled = true
+ s.lastListParams = params
+ s.lastListFilters.platform = platform
+ s.lastListFilters.accountType = accountType
+ s.lastListFilters.status = status
+ s.lastListFilters.search = search
+ s.lastListFilters.groupID = groupID
+ s.lastListFilters.privacyMode = privacyMode
+ if s.listErr != nil {
+ return nil, nil, s.listErr
+ }
+ if s.listResult != nil {
+ return s.listData, s.listResult, nil
+ }
+ return s.listData, &pagination.PaginationResult{Total: int64(len(s.listData))}, nil
+}
+
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
@@ -170,3 +203,46 @@ func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingCon
// No BindGroups should have been called since the check runs before any write.
require.Empty(t, repo.bindGroupsCalls)
}
+
+func TestAdminServiceBulkUpdateAccounts_ResolvesIDsFromFilters(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ listData: []Account{
+ {ID: 7},
+ {ID: 11},
+ },
+ listResult: &pagination.PaginationResult{Total: 2},
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ schedulable := true
+ input := &BulkUpdateAccountsInput{
+ Schedulable: &schedulable,
+ }
+
+ filtersField := reflect.ValueOf(input).Elem().FieldByName("Filters")
+ require.True(t, filtersField.IsValid(), "BulkUpdateAccountsInput should expose Filters for filter-target bulk update")
+ require.Equal(t, reflect.Ptr, filtersField.Kind(), "BulkUpdateAccountsInput.Filters should be a pointer field")
+
+ filtersValue := reflect.New(filtersField.Type().Elem())
+ filtersValue.Elem().FieldByName("Platform").SetString(PlatformOpenAI)
+ filtersValue.Elem().FieldByName("Type").SetString(AccountTypeOAuth)
+ filtersValue.Elem().FieldByName("Status").SetString(StatusActive)
+ filtersValue.Elem().FieldByName("Group").SetString("12")
+ filtersValue.Elem().FieldByName("PrivacyMode").SetString(PrivacyModeCFBlocked)
+ filtersValue.Elem().FieldByName("Search").SetString("bulk-target")
+ filtersField.Set(filtersValue)
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.True(t, repo.listCalled, "expected filter-target bulk update to resolve matching IDs via account list filters")
+ require.Equal(t, PlatformOpenAI, repo.lastListFilters.platform)
+ require.Equal(t, AccountTypeOAuth, repo.lastListFilters.accountType)
+ require.Equal(t, StatusActive, repo.lastListFilters.status)
+ require.Equal(t, "bulk-target", repo.lastListFilters.search)
+ require.Equal(t, int64(12), repo.lastListFilters.groupID)
+ require.Equal(t, PrivacyModeCFBlocked, repo.lastListFilters.privacyMode)
+ require.Equal(t, []int64{7, 11}, repo.bulkUpdateIDs)
+ require.Equal(t, 2, result.Success)
+ require.Equal(t, 0, result.Failed)
+ require.Equal(t, []int64{7, 11}, result.SuccessIDs)
+}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index fbc856cf..fe9e7701 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -13,15 +13,18 @@ import (
)
type userRepoStub struct {
- user *User
- getErr error
- createErr error
- deleteErr error
- exists bool
- existsErr error
- nextID int64
- created []*User
- deletedIDs []int64
+ user *User
+ getErr error
+ createErr error
+ deleteErr error
+ exists bool
+ existsErr error
+ nextID int64
+ created []*User
+ updated []*User
+ deletedIDs []int64
+ usersByEmail map[string]*User
+ getByEmailErr error
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID
}
s.created = append(s.created, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
return nil
}
@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
- panic("unexpected GetByEmail call")
+ if s.getByEmailErr != nil {
+ return nil, s.getByEmailErr
+ }
+ if s.usersByEmail != nil {
+ if user, ok := s.usersByEmail[email]; ok {
+ return user, nil
+ }
+ }
+ if s.user != nil && s.user.Email == email {
+ return s.user, nil
+ }
+ return nil, ErrUserNotFound
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
- panic("unexpected Update call")
+ s.updated = append(s.updated, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
+ return nil
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
@@ -62,6 +87,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
+func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ panic("unexpected GetUserAvatar call")
+}
+
+func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -70,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -101,6 +150,14 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
+func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
new file mode 100644
index 00000000..2232c9c3
--- /dev/null
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -0,0 +1,187 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type ensureEmailCall struct {
+ userID int64
+ email string
+}
+
+type replaceEmailCall struct {
+ userID int64
+ oldEmail string
+ newEmail string
+}
+
+type emailSyncRepoStub struct {
+ user *User
+ nextID int64
+ updateCalls int
+ created []*User
+ updated []*User
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+ ensureErr error
+ replaceErr error
+}
+
+func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
+ if s.nextID != 0 && user.ID == 0 {
+ user.ID = s.nextID
+ }
+ s.created = append(s.created, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ if s.user == nil {
+ return nil, ErrUserNotFound
+ }
+ cloned := *s.user
+ return &cloned, nil
+}
+
+func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) {
+ return nil, ErrUserNotFound
+}
+
+func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) {
+ return nil, fmt.Errorf("unexpected GetFirstAdmin call")
+}
+
+func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error {
+ s.updateCalls++
+ s.updated = append(s.updated, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected GetUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ return fmt.Errorf("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected List call")
+}
+
+func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected ListWithFilters call")
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+
+func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return s.ensureErr
+}
+
+func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return s.replaceErr
+}
+
+func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ nextID: 55,
+ ensureErr: fmt.Errorf("unexpected email resync"),
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ user, err := svc.CreateUser(context.Background(), &CreateUserInput{
+ Email: "admin-created@example.com",
+ Password: "strong-pass",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, int64(55), user.ID)
+ require.Empty(t, repo.ensureCalls)
+ require.Empty(t, repo.replaceCalls)
+}
+
+func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
+ },
+ replaceErr: fmt.Errorf("unexpected email resync"),
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{
+ Email: "after@example.com",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, "after@example.com", updated.Email)
+ require.Empty(t, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/admin_service_group_rate_test.go b/backend/internal/service/admin_service_group_rate_test.go
index 77635247..d2efb644 100644
--- a/backend/internal/service/admin_service_group_rate_test.go
+++ b/backend/internal/service/admin_service_group_rate_test.go
@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
+ "net/http"
"testing"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
+
+ rpmSyncedGroupID int64
+ rpmSyncedEntries []GroupRPMOverrideInput
+ rpmSyncErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
+func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
+ panic("unexpected GetRPMOverrideByUserAndGroup call")
+}
+
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return s.syncGroupErr
}
+func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
+ s.rpmSyncedGroupID = groupID
+ s.rpmSyncedEntries = entries
+ return s.rpmSyncErr
+}
+
+func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
+ panic("unexpected ClearGroupRPMOverrides call")
+}
+
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
- {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
- {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
+ {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
+ {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
},
},
}
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
- require.Equal(t, 1.5, entries[0].RateMultiplier)
+ require.NotNil(t, entries[0].RateMultiplier)
+ require.Equal(t, 1.5, *entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
- require.Equal(t, 0.8, entries[1].RateMultiplier)
+ require.NotNil(t, entries[1].RateMultiplier)
+ require.Equal(t, 0.8, *entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require.Contains(t, err.Error(), "sync failed")
})
}
+
+func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
+ t.Run("syncs entries to repo", func(t *testing.T) {
+ repo := &userGroupRateRepoStubForGroupRate{}
+ svc := &adminServiceImpl{userGroupRateRepo: repo}
+ override := 20
+ entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
+
+ err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
+ require.NoError(t, err)
+ require.Equal(t, int64(10), repo.rpmSyncedGroupID)
+ require.Equal(t, entries, repo.rpmSyncedEntries)
+ })
+
+ t.Run("rejects negative override as bad request", func(t *testing.T) {
+ repo := &userGroupRateRepoStubForGroupRate{}
+ svc := &adminServiceImpl{userGroupRateRepo: repo}
+ negative := -1
+
+ err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
+ {UserID: 2, RPMOverride: &negative},
+ })
+ require.Error(t, err)
+ require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
+ require.Zero(t, repo.rpmSyncedGroupID)
+ })
+}
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index a4c6d0ca..eef02240 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
+func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
+ existingGroup := &Group{
+ ID: 1,
+ Name: "existing-group",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ RPMLimit: 10,
+ }
+ repo := &groupRepoStubForAdmin{getByID: existingGroup}
+ invalidator := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{
+ groupRepo: repo,
+ authCacheInvalidator: invalidator,
+ }
+
+ rpmLimit := 60
+ group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
+ RPMLimit: &rpmLimit,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.Equal(t, 60, repo.updated.RPMLimit)
+ require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
+}
+
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
@@ -621,6 +646,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -641,6 +667,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -695,6 +722,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -713,6 +741,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -733,6 +762,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -750,6 +780,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go
index ceeb52c2..ff3f65a8 100644
--- a/backend/internal/service/admin_service_list_users_test.go
+++ b/backend/internal/service/admin_service_list_users_test.go
@@ -6,6 +6,7 @@ import (
"context"
"errors"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
@@ -16,6 +17,8 @@ type userRepoStubForListUsers struct {
users []User
err error
listWithFiltersParams pagination.PaginationParams
+ lastUsedByUserID map[int64]*time.Time
+ lastUsedErr error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
@@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag
}, nil
}
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ result := make(map[int64]*time.Time, len(userIDs))
+ for _, userID := range userIDs {
+ if ts, ok := s.lastUsedByUserID[userID]; ok {
+ result[userID] = ts
+ }
+ }
+ return result, nil
+}
+
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ return s.lastUsedByUserID[userID], nil
+}
+
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
@@ -66,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
+func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
+ panic("unexpected GetRPMOverrideByUserAndGroup call")
+}
+
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
@@ -78,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic("unexpected SyncGroupRateMultipliers call")
}
+func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
+ panic("unexpected SyncGroupRPMOverrides call")
+}
+
+func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
+ panic("unexpected ClearGroupRPMOverrides call")
+}
+
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}
@@ -130,3 +165,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
+
+func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) {
+ lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second)
+ userRepo := &userRepoStubForListUsers{
+ users: []User{{ID: 101, Email: "u@example.com"}},
+ lastUsedByUserID: map[int64]*time.Time{
+ 101: &lastUsed,
+ },
+ }
+ svc := &adminServiceImpl{userRepo: userRepo}
+
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
+ require.NoError(t, err)
+ require.Equal(t, int64(1), total)
+ require.Len(t, users, 1)
+ require.NotNil(t, users[0].LastUsedAt)
+ require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second)
+}
diff --git a/backend/internal/service/admin_service_rpm_status_test.go b/backend/internal/service/admin_service_rpm_status_test.go
new file mode 100644
index 00000000..c298f69b
--- /dev/null
+++ b/backend/internal/service/admin_service_rpm_status_test.go
@@ -0,0 +1,112 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type rpmStatusUserRepoStub struct {
+ UserRepository
+ user *User
+}
+
+func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ return s.user, nil
+}
+
+type rpmStatusAPIKeyRepoStub struct {
+ APIKeyRepository
+ keys []APIKey
+}
+
+func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
+ return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
+}
+
+type rpmStatusGroupRepoStub struct {
+ GroupRepository
+ groups map[int64]*Group
+}
+
+func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
+ return s.groups[id], nil
+}
+
+type rpmStatusRateRepoStub struct {
+ UserGroupRateRepository
+ overrides map[int64]*int
+}
+
+func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
+ return s.overrides[groupID], nil
+}
+
+type rpmStatusCacheStub struct {
+ UserRPMCache
+ userUsed int
+ groupUsed map[int64]int
+}
+
+func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
+ return 0, nil
+}
+
+func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
+ return 0, nil
+}
+
+func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
+ return s.groupUsed[groupID], nil
+}
+
+func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
+ return s.userUsed, nil
+}
+
+func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
+ groupOneID := int64(1)
+ groupTwoID := int64(2)
+ override := 7
+ svc := &adminServiceImpl{
+ userRepo: &rpmStatusUserRepoStub{user: &User{
+ ID: 42,
+ RPMLimit: 20,
+ }},
+ apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
+ {ID: 100, UserID: 42, GroupID: &groupTwoID},
+ {ID: 101, UserID: 42, GroupID: &groupOneID},
+ {ID: 102, UserID: 42, GroupID: &groupTwoID},
+ {ID: 103, UserID: 42},
+ }},
+ groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
+ groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
+ groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
+ }},
+ userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
+ groupTwoID: &override,
+ }},
+ userRPMCache: &rpmStatusCacheStub{
+ userUsed: 5,
+ groupUsed: map[int64]int{
+ groupOneID: 3,
+ groupTwoID: 4,
+ },
+ },
+ }
+
+ status, err := svc.GetUserRPMStatus(context.Background(), 42)
+ require.NoError(t, err)
+ require.Equal(t, &UserRPMStatus{
+ UserRPMUsed: 5,
+ UserRPMLimit: 20,
+ PerGroup: []UserGroupRPMStatus{
+ {GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
+ {GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
+ },
+ }, status)
+}
diff --git a/backend/internal/service/admin_service_update_user_rpm_test.go b/backend/internal/service/admin_service_update_user_rpm_test.go
new file mode 100644
index 00000000..cb4c3986
--- /dev/null
+++ b/backend/internal/service/admin_service_update_user_rpm_test.go
@@ -0,0 +1,69 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
+// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
+type rpmUserRepoStub struct {
+ *userRepoStub
+ lastUpdated *User
+}
+
+func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
+ if user == nil {
+ return nil
+ }
+ clone := *user
+ s.lastUpdated = &clone
+ if s.userRepoStub != nil {
+ s.userRepoStub.user = &clone
+ }
+ return nil
+}
+
+func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
+ base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
+ repo := &rpmUserRepoStub{userRepoStub: base}
+ invalidator := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{
+ userRepo: repo,
+ redeemCodeRepo: &redeemRepoStub{},
+ authCacheInvalidator: invalidator,
+ }
+
+ newRPM := 60
+ updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
+ RPMLimit: &newRPM,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, 60, updated.RPMLimit)
+ require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
+}
+
+func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
+ base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
+ repo := &rpmUserRepoStub{userRepoStub: base}
+ invalidator := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{
+ userRepo: repo,
+ redeemCodeRepo: &redeemRepoStub{},
+ authCacheInvalidator: invalidator,
+ }
+
+ newName := "new"
+ sameRPM := 10
+ _, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
+ Username: &newName,
+ RPMLimit: &sameRPM,
+ })
+ require.NoError(t, err)
+ require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
+}
diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go
new file mode 100644
index 00000000..5a4e91e7
--- /dev/null
+++ b/backend/internal/service/affiliate_service.go
@@ -0,0 +1,490 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "math"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+var (
+ ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
+ ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
+ ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
+ ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
+ ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
+)
+
+const (
+ affiliateInviteesLimit = 100
+ // AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
+ // 12-char codes and admin-customized codes (e.g. "VIP2026").
+ AffiliateCodeMinLength = 4
+ AffiliateCodeMaxLength = 32
+)
+
+// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
+// All input passes through strings.ToUpper before validation, so lowercase from
+// users is normalized — admins may supply mixed case in their UI.
+var affiliateCodeValidChar = func() [256]bool {
+ var tbl [256]bool
+ for c := byte('A'); c <= 'Z'; c++ {
+ tbl[c] = true
+ }
+ for c := byte('0'); c <= '9'; c++ {
+ tbl[c] = true
+ }
+ tbl['_'] = true
+ tbl['-'] = true
+ return tbl
+}()
+
+// isValidAffiliateCodeFormat validates code format for both binding (user input)
+// and admin updates. Caller is expected to upper-case the input first.
+func isValidAffiliateCodeFormat(code string) bool {
+ if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
+ return false
+ }
+ for i := 0; i < len(code); i++ {
+ if !affiliateCodeValidChar[code[i]] {
+ return false
+ }
+ }
+ return true
+}
+
+type AffiliateSummary struct {
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ AffCodeCustom bool `json:"aff_code_custom"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffFrozenQuota float64 `json:"aff_frozen_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+type AffiliateInvitee struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ CreatedAt *time.Time `json:"created_at,omitempty"`
+ TotalRebate float64 `json:"total_rebate"`
+}
+
+type AffiliateDetail struct {
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffFrozenQuota float64 `json:"aff_frozen_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
+ // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
+ // 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
+ EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
+ Invitees []AffiliateInvitee `json:"invitees"`
+}
+
+type AffiliateRepository interface {
+ EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
+ GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
+ BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
+ AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
+ GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
+ ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
+ TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
+ ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
+
+ // 管理端:用户级专属配置
+ UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
+ ResetUserAffCode(ctx context.Context, userID int64) (string, error)
+ SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
+ BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
+ ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
+}
+
+// AffiliateAdminFilter 列表筛选条件
+type AffiliateAdminFilter struct {
+ Search string
+ Page int
+ PageSize int
+}
+
+// AffiliateAdminEntry 专属用户列表条目
+type AffiliateAdminEntry struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ AffCode string `json:"aff_code"`
+ AffCodeCustom bool `json:"aff_code_custom"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
+ AffCount int `json:"aff_count"`
+}
+
+type AffiliateService struct {
+ repo AffiliateRepository
+ settingService *SettingService
+ authCacheInvalidator APIKeyAuthCacheInvalidator
+ billingCacheService *BillingCacheService
+}
+
+func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
+ return &AffiliateService{
+ repo: repo,
+ settingService: settingService,
+ authCacheInvalidator: authCacheInvalidator,
+ billingCacheService: billingCacheService,
+ }
+}
+
+// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
+func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
+ if s == nil || s.settingService == nil {
+ return AffiliateEnabledDefault
+ }
+ return s.settingService.IsAffiliateEnabled(ctx)
+}
+
+func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
+ }
+ if s == nil || s.repo == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.EnsureUserAffiliate(ctx, userID)
+}
+
+func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
+ // Lazy thaw: move any matured frozen quota to available before reading.
+ if s != nil && s.repo != nil {
+ // best-effort: thaw failure is non-fatal
+ _, _ = s.repo.ThawFrozenQuota(ctx, userID)
+ }
+
+ summary, err := s.EnsureUserAffiliate(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ invitees, err := s.listInvitees(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ return &AffiliateDetail{
+ UserID: summary.UserID,
+ AffCode: summary.AffCode,
+ InviterID: summary.InviterID,
+ AffCount: summary.AffCount,
+ AffQuota: summary.AffQuota,
+ AffFrozenQuota: summary.AffFrozenQuota,
+ AffHistoryQuota: summary.AffHistoryQuota,
+ EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
+ Invitees: invitees,
+ }, nil
+}
+
+func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
+ code := strings.ToUpper(strings.TrimSpace(rawCode))
+ if code == "" {
+ return nil
+ }
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ // 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
+ if !s.IsEnabled(ctx) {
+ return nil
+ }
+ if !isValidAffiliateCodeFormat(code) {
+ return ErrAffiliateCodeInvalid
+ }
+
+ selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
+ if err != nil {
+ return err
+ }
+ if selfSummary.InviterID != nil {
+ return nil
+ }
+
+ inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
+ if err != nil {
+ if errors.Is(err, ErrAffiliateProfileNotFound) {
+ return ErrAffiliateCodeInvalid
+ }
+ return err
+ }
+ if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
+ return ErrAffiliateCodeInvalid
+ }
+
+ bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
+ if err != nil {
+ return err
+ }
+ if !bound {
+ return ErrAffiliateAlreadyBound
+ }
+ return nil
+}
+
+func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
+ if s == nil || s.repo == nil {
+ return 0, nil
+ }
+ if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
+ return 0, nil
+ }
+ // 总开关关闭时,新充值不再产生返利
+ if !s.IsEnabled(ctx) {
+ return 0, nil
+ }
+
+ inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
+ if err != nil {
+ return 0, err
+ }
+ if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
+ return 0, nil
+ }
+
+ // 加载邀请人 profile,优先使用专属比例(覆盖全局)
+ inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
+ if err != nil {
+ return 0, err
+ }
+ // 有效期检查:超过返利有效期后不再产生返利
+ if s.settingService != nil {
+ if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
+ if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
+ return 0, nil
+ }
+ }
+ }
+
+ rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
+ rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
+ if rebate <= 0 {
+ return 0, nil
+ }
+
+ // 单人上限检查:精确截断到剩余额度
+ if s.settingService != nil {
+ if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
+ existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
+ if err != nil {
+ return 0, err
+ }
+ if existing >= perInviteeCap {
+ return 0, nil
+ }
+ if remaining := perInviteeCap - existing; rebate > remaining {
+ rebate = roundTo(remaining, 8)
+ }
+ }
+ }
+
+ var freezeHours int
+ if s.settingService != nil {
+ freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
+ }
+
+ applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
+ if err != nil {
+ return 0, err
+ }
+ if !applied {
+ return 0, nil
+ }
+ return rebate, nil
+}
+
+// resolveRebateRatePercent returns the inviter's exclusive rate when set,
+// otherwise the global setting value (clamped to [Min, Max]).
+func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
+ if inviter != nil && inviter.AffRebateRatePercent != nil {
+ v := *inviter.AffRebateRatePercent
+ if math.IsNaN(v) || math.IsInf(v, 0) {
+ return s.globalRebateRatePercent(ctx)
+ }
+ return clampAffiliateRebateRate(v)
+ }
+ return s.globalRebateRatePercent(ctx)
+}
+
+// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
+// returning the documented default when SettingService is unavailable.
+func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
+ if s == nil || s.settingService == nil {
+ return AffiliateRebateRateDefault
+ }
+ return s.settingService.GetAffiliateRebateRatePercent(ctx)
+}
+
+func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
+ if s == nil || s.repo == nil {
+ return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+
+ transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
+ if err != nil {
+ return 0, 0, err
+ }
+ if transferred > 0 {
+ s.invalidateAffiliateCaches(ctx, userID)
+ }
+ return transferred, balance, nil
+}
+
+func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
+ if s == nil || s.repo == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
+ if err != nil {
+ return nil, err
+ }
+ for i := range invitees {
+ invitees[i].Email = maskEmail(invitees[i].Email)
+ }
+ return invitees, nil
+}
+
+func roundTo(v float64, scale int) float64 {
+ factor := math.Pow10(scale)
+ return math.Round(v*factor) / factor
+}
+
+func maskEmail(email string) string {
+ email = strings.TrimSpace(email)
+ if email == "" {
+ return ""
+ }
+ at := strings.Index(email, "@")
+ if at <= 0 || at >= len(email)-1 {
+ return "***"
+ }
+
+ local := email[:at]
+ domain := email[at+1:]
+ dot := strings.LastIndex(domain, ".")
+
+ maskedLocal := maskSegment(local)
+ if dot <= 0 || dot >= len(domain)-1 {
+ return maskedLocal + "@" + maskSegment(domain)
+ }
+
+ domainName := domain[:dot]
+ tld := domain[dot:]
+ return maskedLocal + "@" + maskSegment(domainName) + tld
+}
+
+func maskSegment(s string) string {
+ r := []rune(s)
+ if len(r) == 0 {
+ return "***"
+ }
+ if len(r) == 1 {
+ return string(r[0]) + "***"
+ }
+ return string(r[0]) + "***"
+}
+
+func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ if s.billingCacheService != nil {
+ if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
+ }
+ }
+}
+
+// =========================
+// Admin: 专属配置管理
+// =========================
+
+// validateExclusiveRate ensures a per-user override is finite and within
+// [Min, Max]. nil is always valid (means "clear / fall back to global").
+func validateExclusiveRate(ratePercent *float64) error {
+ if ratePercent == nil {
+ return nil
+ }
+ v := *ratePercent
+ if math.IsNaN(v) || math.IsInf(v, 0) {
+ return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
+ }
+ if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
+ return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
+ }
+ return nil
+}
+
+// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
+func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ code := strings.ToUpper(strings.TrimSpace(rawCode))
+ if !isValidAffiliateCodeFormat(code) {
+ return ErrAffiliateCodeInvalid
+ }
+ return s.repo.UpdateUserAffCode(ctx, userID, code)
+}
+
+// AdminResetUserAffCode 重置用户邀请码为系统随机码。
+func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
+ if s == nil || s.repo == nil {
+ return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.ResetUserAffCode(ctx, userID)
+}
+
+// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
+func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ if err := validateExclusiveRate(ratePercent); err != nil {
+ return err
+ }
+ return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
+}
+
+// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
+func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ if err := validateExclusiveRate(ratePercent); err != nil {
+ return err
+ }
+ cleaned := make([]int64, 0, len(userIDs))
+ for _, uid := range userIDs {
+ if uid > 0 {
+ cleaned = append(cleaned, uid)
+ }
+ }
+ if len(cleaned) == 0 {
+ return nil
+ }
+ return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
+}
+
+// AdminListCustomUsers 列出有专属配置的用户。
+func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
+ if s == nil || s.repo == nil {
+ return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.ListUsersWithCustomSettings(ctx, filter)
+}
diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go
new file mode 100644
index 00000000..c02a4dd7
--- /dev/null
+++ b/backend/internal/service/affiliate_service_test.go
@@ -0,0 +1,131 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "math"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
+// AffRebateRatePercent overrides the global rate, that NULL falls back to the
+// global rate, and that out-of-range exclusive rates are clamped silently.
+//
+// SettingService is left nil here so globalRebateRatePercent returns the
+// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
+// fallback path without spinning up a settings stub.
+func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
+ t.Parallel()
+ svc := &AffiliateService{}
+
+ // nil exclusive rate → falls back to global default (20%)
+ require.InDelta(t, AffiliateRebateRateDefault,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
+
+ // exclusive rate set → overrides global
+ rate := 50.0
+ require.InDelta(t, 50.0,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
+
+ // exclusive rate 0 → returns 0 (no rebate, intentional)
+ zero := 0.0
+ require.InDelta(t, 0.0,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
+
+ // exclusive rate above max → clamped to Max
+ tooHigh := 250.0
+ require.InDelta(t, AffiliateRebateRateMax,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
+
+ // exclusive rate below min → clamped to Min
+ tooLow := -5.0
+ require.InDelta(t, AffiliateRebateRateMin,
+ svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
+}
+
+// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
+// safely handles a nil settingService dependency by returning the default
+// (off). This protects callers from nil-pointer crashes in misconfigured
+// environments.
+func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
+ t.Parallel()
+ svc := &AffiliateService{}
+ require.False(t, svc.IsEnabled(context.Background()))
+ require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
+}
+
+// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
+// admin-facing rate setters: nil is always valid (clear), in-range values
+// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
+func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
+ t.Parallel()
+ require.NoError(t, validateExclusiveRate(nil))
+
+ for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
+ v := v
+ require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
+ }
+
+ for _, v := range []float64{-0.01, 100.01, -100, 200} {
+ v := v
+ require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
+ }
+
+ nan := math.NaN()
+ require.Error(t, validateExclusiveRate(&nan))
+ posInf := math.Inf(1)
+ require.Error(t, validateExclusiveRate(&posInf))
+ negInf := math.Inf(-1)
+ require.Error(t, validateExclusiveRate(&negInf))
+}
+
+func TestMaskEmail(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
+ require.Equal(t, "x***@d***", maskEmail("x@domain"))
+ require.Equal(t, "", maskEmail(""))
+}
+
+func TestIsValidAffiliateCodeFormat(t *testing.T) {
+ t.Parallel()
+
+ // 邀请码格式校验同时服务于:
+ // 1) 系统自动生成的 12 位随机码(A-Z 去 I/O,2-9 去 0/1)
+ // 2) 管理员设置的自定义专属码(如 "VIP2026"、"NEW_USER-1")
+ // 因此校验放宽到 [A-Z0-9_-]{4,32}(要求调用方先 ToUpper)。
+ cases := []struct {
+ name string
+ in string
+ want bool
+ }{
+ {"valid canonical 12-char", "ABCDEFGHJKLM", true},
+ {"valid all digits 2-9", "234567892345", true},
+ {"valid mixed", "A2B3C4D5E6F7", true},
+ {"valid admin custom short", "VIP1", true},
+ {"valid admin custom with hyphen", "NEW-USER", true},
+ {"valid admin custom with underscore", "VIP_2026", true},
+ {"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
+ // Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
+ {"letter I now allowed", "IBCDEFGHJKLM", true},
+ {"letter O now allowed", "OBCDEFGHJKLM", true},
+ {"digit 0 now allowed", "0BCDEFGHJKLM", true},
+ {"digit 1 now allowed", "1BCDEFGHJKLM", true},
+ {"too short (3 chars)", "ABC", false},
+ {"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
+ {"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
+ {"empty", "", false},
+ {"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
+ {"ascii punctuation .", "ABCDEFGHJK.M", false},
+ {"whitespace", "ABCDEFGHJK M", false},
+ }
+ for _, tc := range cases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
+ })
+ }
+}
diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go
index 25c66eb4..02741d37 100644
--- a/backend/internal/service/announcement.go
+++ b/backend/internal/service/announcement.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -34,8 +35,23 @@ const (
)
var (
- ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
- ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+ ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
+ ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+ ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required")
+ ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid")
+ ErrAnnouncementContentRequired = infraerrors.BadRequest(
+ "ANNOUNCEMENT_CONTENT_REQUIRED",
+ "announcement content is required",
+ )
+ ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid")
+ ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest(
+ "ANNOUNCEMENT_NOTIFY_MODE_INVALID",
+ "announcement notify_mode is invalid",
+ )
+ ErrAnnouncementInvalidSchedule = infraerrors.BadRequest(
+ "ANNOUNCEMENT_TIME_RANGE_INVALID",
+ "starts_at must be before ends_at",
+ )
)
type AnnouncementTargeting = domain.AnnouncementTargeting
diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go
index c0a0681a..12479041 100644
--- a/backend/internal/service/announcement_service.go
+++ b/backend/internal/service/announcement_service.go
@@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct {
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
- return nil, fmt.Errorf("create announcement: nil input")
+ return nil, ErrAnnouncementNilInput
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
- return nil, fmt.Errorf("create announcement: invalid title")
+ return nil, ErrAnnouncementInvalidTitle
}
if content == "" {
- return nil, fmt.Errorf("create announcement: content is required")
+ return nil, ErrAnnouncementContentRequired
}
status := strings.TrimSpace(input.Status)
@@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
- return nil, fmt.Errorf("create announcement: invalid status")
+ return nil, ErrAnnouncementInvalidStatus
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
@@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
- return nil, fmt.Errorf("create announcement: invalid notify_mode")
+ return nil, ErrAnnouncementInvalidNotifyMode
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
- return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
+ return nil, ErrAnnouncementInvalidSchedule
}
}
@@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
- return nil, fmt.Errorf("update announcement: nil input")
+ return nil, ErrAnnouncementNilInput
}
a, err := s.announcementRepo.GetByID(ctx, id)
@@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
- return nil, fmt.Errorf("update announcement: invalid title")
+ return nil, ErrAnnouncementInvalidTitle
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
- return nil, fmt.Errorf("update announcement: content is required")
+ return nil, ErrAnnouncementContentRequired
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
- return nil, fmt.Errorf("update announcement: invalid status")
+ return nil, ErrAnnouncementInvalidStatus
}
a.Status = status
}
@@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
- return nil, fmt.Errorf("update announcement: invalid notify_mode")
+ return nil, ErrAnnouncementInvalidNotifyMode
}
a.NotifyMode = notifyMode
}
@@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
- return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
+ return nil, ErrAnnouncementInvalidSchedule
}
}
diff --git a/backend/internal/service/announcement_service_test.go b/backend/internal/service/announcement_service_test.go
new file mode 100644
index 00000000..77fb9896
--- /dev/null
+++ b/backend/internal/service/announcement_service_test.go
@@ -0,0 +1,81 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type announcementRepoStub struct {
+ item *Announcement
+}
+
+func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error {
+ s.item = a
+ return nil
+}
+
+func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) {
+ if s.item == nil {
+ return nil, ErrAnnouncementNotFound
+ }
+ return s.item, nil
+}
+
+func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error {
+ s.item = a
+ return nil
+}
+
+func (*announcementRepoStub) Delete(context.Context, int64) error {
+ return nil
+}
+
+func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+
+func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) {
+ return nil, nil
+}
+
+func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) {
+ repo := &announcementRepoStub{}
+ svc := NewAnnouncementService(repo, nil, nil, nil)
+ now := time.Unix(1776790020, 0)
+
+ _, err := svc.Create(context.Background(), &CreateAnnouncementInput{
+ Title: "公告",
+ Content: "内容",
+ Status: AnnouncementStatusActive,
+ NotifyMode: AnnouncementNotifyModePopup,
+ StartsAt: &now,
+ EndsAt: &now,
+ })
+ require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
+}
+
+func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) {
+ repo := &announcementRepoStub{
+ item: &Announcement{
+ ID: 1,
+ Title: "公告",
+ Content: "内容",
+ Status: AnnouncementStatusActive,
+ NotifyMode: AnnouncementNotifyModePopup,
+ },
+ }
+ svc := NewAnnouncementService(repo, nil, nil, nil)
+ now := time.Unix(1776790020, 0)
+ startsAt := &now
+ endsAt := &now
+
+ _, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{
+ StartsAt: &startsAt,
+ EndsAt: &endsAt,
+ })
+ require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
+}
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index b1660ea7..1a1c78b8 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
+
+ // RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
+ RPMLimit int `json:"rpm_limit"`
+
+ // UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
+ // nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
+ UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
@@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct {
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
+
+ // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
+ RPMLimit int `json:"rpm_limit"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index 2bd9a091..974ea66e 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
-const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
+const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
- snapshot := s.snapshotFromAPIKey(apiKey)
+ snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
@@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
-func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
+func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
@@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
+ RPMLimit: apiKey.User.RPMLimit,
},
}
+
+ // 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
+ if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
+ override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
+ if err == nil && override != nil {
+ snapshot.User.UserGroupRPMOverride = override
+ }
+ // 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
+ }
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
@@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
+ RPMLimit: apiKey.Group.RPMLimit,
}
}
return snapshot
@@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
+ RPMLimit: snapshot.User.RPMLimit,
+ UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
},
}
if snapshot.Group != nil {
@@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
+ RPMLimit: snapshot.Group.RPMLimit,
}
}
s.compileAPIKeyIPRules(apiKey)
diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go
index 3c2f7dbb..8cb1b8c4 100644
--- a/backend/internal/service/api_key_service_cache_test.go
+++ b/backend/internal/service/api_key_service_cache_test.go
@@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t
},
}
- snapshot := svc.snapshotFromAPIKey(apiKey)
+ snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
require.NotNil(t, roundTrip)
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
new file mode 100644
index 00000000..78f1185d
--- /dev/null
+++ b/backend/internal/service/auth_email_binding.go
@@ -0,0 +1,319 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+// BindEmailIdentity verifies and binds a local email/password identity to the
+// current user, or replaces the existing bound primary email.
+func (s *AuthService) BindEmailIdentity(
+ ctx context.Context,
+ userID int64,
+ email string,
+ verifyCode string,
+ password string,
+) (*User, error) {
+ if s == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+ if err != nil {
+ return nil, err
+ }
+ if isReservedEmail(normalizedEmail) {
+ return nil, ErrEmailReserved
+ }
+ if strings.TrimSpace(password) == "" {
+ return nil, ErrPasswordRequired
+ }
+ if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
+ return nil, err
+ }
+
+ currentUser, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+ if firstRealEmailBind && len(password) < 6 {
+ return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters")
+ }
+ if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) {
+ return nil, ErrPasswordIncorrect
+ }
+
+ existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+ switch {
+ case err == nil && existingUser != nil && existingUser.ID != userID:
+ return nil, ErrEmailExists
+ case err != nil && !errors.Is(err, ErrUserNotFound):
+ return nil, ErrServiceUnavailable
+ }
+
+ hashedPassword, err := s.HashPassword(password)
+ if err != nil {
+ return nil, fmt.Errorf("hash password: %w", err)
+ }
+
+ if s.entClient != nil {
+ if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
+ return nil, err
+ }
+ s.revokeEmailIdentitySessions(ctx, userID)
+ return currentUser, nil
+ }
+
+ currentUser.Email = normalizedEmail
+ currentUser.PasswordHash = hashedPassword
+ if err := s.userRepo.Update(ctx, currentUser); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return nil, ErrEmailExists
+ }
+ return nil, ErrServiceUnavailable
+ }
+
+ if firstRealEmailBind {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
+ return nil, fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+ }
+
+ s.revokeEmailIdentitySessions(ctx, userID)
+ return currentUser, nil
+}
+
+// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
+func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
+ if s == nil {
+ return ErrServiceUnavailable
+ }
+
+ normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+ if err != nil {
+ return err
+ }
+ if isReservedEmail(normalizedEmail) {
+ return ErrEmailReserved
+ }
+ if s.emailService == nil {
+ return ErrServiceUnavailable
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ return ErrUserNotFound
+ }
+ return ErrServiceUnavailable
+ }
+
+ existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+ switch {
+ case err == nil && existingUser != nil && existingUser.ID != userID:
+ return ErrEmailExists
+ case err != nil && !errors.Is(err, ErrUserNotFound):
+ return ErrServiceUnavailable
+ }
+
+ siteName := "Sub2API"
+ if s.settingService != nil {
+ siteName = s.settingService.GetSiteName(ctx)
+ }
+ return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
+}
+
+func normalizeEmailForIdentityBinding(email string) (string, error) {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || len(normalized) > 255 {
+ return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+ }
+ if _, err := mail.ParseAddress(normalized); err != nil {
+ return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+ }
+ return normalized, nil
+}
+
+func hasBindableEmailIdentitySubject(email string) bool {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ return normalized != "" && !isReservedEmail(normalized)
+}
+
+func (s *AuthService) updateBoundEmailIdentityTx(
+ ctx context.Context,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+ applyFirstBindDefaults bool,
+) error {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil {
+ return err
+ }
+ if err := tx.Commit(); err != nil {
+ return ErrServiceUnavailable
+ }
+ return nil
+}
+
+func (s *AuthService) updateBoundEmailIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ currentUser *User,
+ email string,
+ hashedPassword string,
+ applyFirstBindDefaults bool,
+) error {
+ if client == nil || currentUser == nil || currentUser.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ oldEmail := currentUser.Email
+ if _, err := client.User.UpdateOneID(currentUser.ID).
+ SetEmail(email).
+ SetPasswordHash(hashedPassword).
+ Save(ctx); err != nil {
+ if dbent.IsConstraintError(err) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return ErrEmailExists
+ }
+ return ErrServiceUnavailable
+ }
+
+ if applyFirstBindDefaults {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+ return fmt.Errorf("apply email first bind defaults: %w", err)
+ }
+ }
+
+ updatedUser, err := client.User.Get(ctx, currentUser.ID)
+ if err != nil {
+ return ErrServiceUnavailable
+ }
+ currentUser.Email = updatedUser.Email
+ currentUser.PasswordHash = updatedUser.PasswordHash
+ currentUser.Balance = updatedUser.Balance
+ currentUser.Concurrency = updatedUser.Concurrency
+ currentUser.UpdatedAt = updatedUser.UpdatedAt
+ return nil
+}
+
+func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) {
+ if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err)
+ }
+}
+
+func replaceBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ oldEmail string,
+ newEmail string,
+ source string,
+) error {
+ newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
+ if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
+ return err
+ }
+
+ oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
+ if oldSubject == "" || oldSubject == newSubject {
+ return nil
+ }
+
+ _, err := client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(oldSubject),
+ ).
+ Exec(ctx)
+ return err
+}
+
+func ensureBoundEmailAuthIdentityWithClient(
+ ctx context.Context,
+ client *dbent.Client,
+ userID int64,
+ subject string,
+ source string,
+) error {
+ if client == nil || userID <= 0 || subject == "" {
+ return nil
+ }
+
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_email_bind"
+ }
+
+ if err := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(subject).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ if !isSQLNoRowsError(err) {
+ return err
+ }
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(subject),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity.UserID != userID {
+ return ErrEmailExists
+ }
+ return nil
+}
+
+func normalizeBoundEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" || isReservedEmail(normalized) {
+ return ""
+ }
+ return normalized
+}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
new file mode 100644
index 00000000..9815f31b
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -0,0 +1,387 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
+)
+
+func normalizeOAuthSignupSource(signupSource string) string {
+ signupSource = strings.TrimSpace(strings.ToLower(signupSource))
+ switch signupSource {
+ case "", "email":
+ return "email"
+ case "linuxdo", "wechat", "oidc":
+ return signupSource
+ default:
+ return "email"
+ }
+}
+
+// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
+// account-creation flows without relying on the public registration gate.
+func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" {
+ return nil, ErrEmailVerifyRequired
+ }
+ if _, err := mail.ParseAddress(email); err != nil {
+ return nil, ErrEmailVerifyRequired
+ }
+ if isReservedEmail(email) {
+ return nil, ErrEmailReserved
+ }
+ if s == nil || s.emailService == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ siteName := "Sub2API"
+ if s.settingService != nil {
+ siteName = s.settingService.GetSiteName(ctx)
+ }
+ if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
+ return nil, err
+ }
+ return &SendVerifyCodeResult{
+ Countdown: int(verifyCodeCooldown / time.Second),
+ }, nil
+}
+
+func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+ return nil, nil
+ }
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ invitationCode = strings.TrimSpace(invitationCode)
+ if invitationCode == "" {
+ return nil, ErrInvitationCodeRequired
+ }
+
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ return nil, ErrInvitationCodeInvalid
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ return nil, ErrInvitationCodeInvalid
+ }
+ return redeemCode, nil
+}
+
+// VerifyOAuthEmailCode verifies the locally entered email verification code for
+// third-party signup and binding flows. This is intentionally independent from
+// the global registration email verification toggle.
+func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
+ email = strings.TrimSpace(strings.ToLower(email))
+ verifyCode = strings.TrimSpace(verifyCode)
+
+ if email == "" {
+ return ErrEmailVerifyRequired
+ }
+ if verifyCode == "" {
+ return ErrEmailVerifyRequired
+ }
+ if s == nil || s.emailService == nil {
+ return ErrServiceUnavailable
+ }
+ return s.emailService.VerifyCode(ctx, email, verifyCode)
+}
+
+// RegisterOAuthEmailAccount creates a local account from a third-party first
+// login after the user has verified a local email address.
+func (s *AuthService) RegisterOAuthEmailAccount(
+ ctx context.Context,
+ email string,
+ password string,
+ verifyCode string,
+ invitationCode string,
+ signupSource string,
+) (*TokenPair, *User, error) {
+ if s == nil {
+ return nil, nil, ErrServiceUnavailable
+ }
+ if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+ return nil, nil, ErrRegDisabled
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if isReservedEmail(email) {
+ return nil, nil, ErrEmailReserved
+ }
+ if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+ return nil, nil, err
+ }
+ if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
+ return nil, nil, err
+ }
+
+ if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
+ return nil, nil, err
+ }
+
+ existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
+ if err != nil {
+ return nil, nil, ErrServiceUnavailable
+ }
+ if existsEmail {
+ return nil, nil, ErrEmailExists
+ }
+
+ hashedPassword, err := s.HashPassword(password)
+ if err != nil {
+ return nil, nil, fmt.Errorf("hash password: %w", err)
+ }
+
+ signupSource = normalizeOAuthSignupSource(signupSource)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+
+ user := &User{
+ Email: email,
+ PasswordHash: hashedPassword,
+ Role: RoleUser,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ Status: StatusActive,
+ SignupSource: signupSource,
+ }
+
+ if err := s.userRepo.Create(ctx, user); err != nil {
+ if errors.Is(err, ErrEmailExists) {
+ return nil, nil, ErrEmailExists
+ }
+ return nil, nil, ErrServiceUnavailable
+ }
+
+ tokenPair, err := s.GenerateTokenPair(ctx, user, "")
+ if err != nil {
+ _ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
+ return nil, nil, fmt.Errorf("generate token pair: %w", err)
+ }
+ return tokenPair, user, nil
+}
+
+// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
+// only after the pending OAuth flow has fully reached its last reversible step.
+func (s *AuthService) FinalizeOAuthEmailAccount(
+ ctx context.Context,
+ user *User,
+ invitationCode string,
+ signupSource string,
+ affiliateCode string,
+) error {
+ if s == nil || user == nil || user.ID <= 0 {
+ return ErrServiceUnavailable
+ }
+
+ signupSource = normalizeOAuthSignupSource(signupSource)
+ invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ return err
+ }
+ if invitationRedeemCode != nil {
+ if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+ return ErrInvitationCodeInvalid
+ }
+ }
+
+ s.updateOAuthSignupSource(ctx, user.ID, signupSource)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
+ return nil
+}
+
+// RollbackOAuthEmailAccountCreation removes a partially-created local account
+// and restores any invitation code already consumed by that account.
+func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return ErrServiceUnavailable
+ }
+ if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
+ return err
+ }
+ if err := s.userRepo.Delete(ctx, userID); err != nil {
+ return fmt.Errorf("delete created oauth user: %w", err)
+ }
+ return nil
+}
+
+func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
+ if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+ return nil
+ }
+ if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
+ return ErrServiceUnavailable
+ }
+
+ invitationCode = strings.TrimSpace(invitationCode)
+ if invitationCode == "" || userID <= 0 {
+ return nil
+ }
+
+ redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
+ if err != nil {
+ if errors.Is(err, ErrRedeemCodeNotFound) {
+ return nil
+ }
+ return fmt.Errorf("load invitation code: %w", err)
+ }
+ if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
+ return nil
+ }
+
+ redeemCode.Status = StatusUnused
+ redeemCode.UsedBy = nil
+ redeemCode.UsedAt = nil
+ if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
+ return fmt.Errorf("restore invitation code: %w", err)
+ }
+ return nil
+}
+
+func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
+ if s == nil || s.entClient == nil {
+ return nil
+ }
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return tx.Client()
+ }
+ return s.entClient
+}
+
+func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ return &RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: oauthEmailFlowStringValue(entity.Notes),
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+ }
+ return s.redeemRepo.GetByCode(ctx, invitationCode)
+}
+
+func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ affected, err := client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
+ SetStatus(StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return ErrRedeemCodeUsed
+ }
+ return nil
+ }
+ return s.redeemRepo.Use(ctx, invitationID, userID)
+}
+
+func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ if client := s.oauthEmailFlowClient(ctx); client != nil {
+ update := client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+ }
+ return s.redeemRepo.Update(ctx, code)
+}
+
+func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
+ client := s.oauthEmailFlowClient(ctx)
+ if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ _ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
+}
+
+func oauthEmailFlowStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
+
+// ValidatePasswordCredentials checks the local password without completing the
+// login flow. This is used by pending third-party account adoption flows before
+// the external identity has been bound.
+func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
+ if s == nil {
+ return nil, ErrServiceUnavailable
+ }
+
+ user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
+ if err != nil {
+ if errors.Is(err, ErrUserNotFound) {
+ return nil, ErrInvalidCredentials
+ }
+ return nil, ErrServiceUnavailable
+ }
+ if !user.IsActive() {
+ return nil, ErrUserNotActive
+ }
+ if !s.CheckPassword(password, user.PasswordHash) {
+ return nil, ErrInvalidCredentials
+ }
+ return user, nil
+}
+
+// RecordSuccessfulLogin updates last-login activity after a non-standard login
+// flow finishes with a real session.
+func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
+ if s != nil && s.userRepo != nil && userID > 0 {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err == nil && user != nil && !isReservedEmail(user.Email) {
+ s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
+ }
+ }
+ s.touchUserLogin(ctx, userID)
+}
diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go
new file mode 100644
index 00000000..21d9d6e9
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow_test.go
@@ -0,0 +1,326 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type redeemCodeRepoStub struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+ updateCalls []*RedeemCode
+}
+
+func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ if s.codesByCode == nil {
+ return nil, ErrRedeemCodeNotFound
+ }
+ redeemCode, ok := s.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ cloned := *code
+ s.updateCalls = append(s.updateCalls, &cloned)
+ if s.codesByCode == nil {
+ s.codesByCode = make(map[string]*RedeemCode)
+ }
+ s.codesByCode[cloned.Code] = &cloned
+ return nil
+}
+
+func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range s.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ s.codesByCode[code] = redeemCode
+ s.useCalls = append(s.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+func newOAuthEmailFlowAuthService(
+ userRepo UserRepository,
+ redeemRepo RedeemCodeRepository,
+ refreshTokenCache RefreshTokenCache,
+ settings map[string]string,
+ emailCache EmailCache,
+) *AuthService {
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
+ emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
+
+ return NewAuthService(
+ nil,
+ userRepo,
+ redeemRepo,
+ refreshTokenCache,
+ cfg,
+ settingService,
+ emailService,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+}
+
+func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 42}
+ redeemRepo := &redeemCodeRepoStub{
+ codesByCode: map[string]*RedeemCode{
+ "INVITE123": {
+ ID: 7,
+ Code: "INVITE123",
+ Type: RedeemTypeInvitation,
+ Status: StatusUnused,
+ },
+ },
+ }
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ redeemRepo,
+ nil,
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyInvitationCodeEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fresh@example.com",
+ "secret-123",
+ "246810",
+ "INVITE123",
+ "oidc",
+ )
+
+ require.Nil(t, tokenPair)
+ require.Nil(t, user)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "generate token pair")
+ require.Equal(t, []int64{42}, userRepo.deletedIDs)
+ require.Len(t, userRepo.created, 1)
+ require.Empty(t, redeemRepo.useCalls)
+ require.Empty(t, redeemRepo.updateCalls)
+}
+
+func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 42}
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fresh@example.com",
+ "secret-123",
+ "246810",
+ "",
+ " OIDC ",
+ )
+
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Len(t, userRepo.created, 1)
+ require.Equal(t, "oidc", userRepo.created[0].SignupSource)
+}
+
+func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
+ userRepo := &userRepoStub{nextID: 43}
+ emailCache := &emailCacheStub{
+ data: &VerificationCodeData{
+ Code: "246810",
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "true",
+ },
+ emailCache,
+ )
+
+ tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+ context.Background(),
+ "fallback@example.com",
+ "secret-123",
+ "246810",
+ "",
+ "github",
+ )
+
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Len(t, userRepo.created, 1)
+ require.Equal(t, "email", userRepo.created[0].SignupSource)
+}
+
+func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
+ userRepo := &userRepoStub{}
+ redeemRepo := &redeemCodeRepoStub{
+ codesByCode: map[string]*RedeemCode{
+ "INVITE123": {
+ ID: 7,
+ Code: "INVITE123",
+ Type: RedeemTypeInvitation,
+ Status: StatusUsed,
+ UsedBy: func() *int64 {
+ v := int64(42)
+ return &v
+ }(),
+ UsedAt: func() *time.Time {
+ v := time.Now().UTC()
+ return &v
+ }(),
+ },
+ },
+ }
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ redeemRepo,
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyInvitationCodeEnabled: "true",
+ },
+ &emailCacheStub{},
+ )
+
+ err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
+
+ require.NoError(t, err)
+ require.Equal(t, []int64{42}, userRepo.deletedIDs)
+ require.Len(t, redeemRepo.updateCalls, 1)
+ require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
+ require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
+ require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
+}
+
+func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
+ userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
+ authService := newOAuthEmailFlowAuthService(
+ userRepo,
+ &redeemCodeRepoStub{},
+ &refreshTokenCacheStub{},
+ map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ },
+ &emailCacheStub{},
+ )
+
+ err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "delete created oauth user")
+}
diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go
new file mode 100644
index 00000000..aa06e59f
--- /dev/null
+++ b/backend/internal/service/auth_oauth_first_bind.go
@@ -0,0 +1,104 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+
+ entsql "entgo.io/ent/dialect/sql"
+)
+
+// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap
+// settings the first time a user binds a third-party identity. The grant is
+// idempotent per user/provider pair.
+func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+) error {
+ if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 {
+ return nil
+ }
+
+ if dbent.TxFromContext(ctx) != nil {
+ return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return fmt.Errorf("begin first bind defaults transaction: %w", err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+) error {
+ providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
+ if err != nil {
+ return fmt.Errorf("load auth source defaults: %w", err)
+ }
+ if !enabled {
+ return nil
+ }
+
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
+ }
+
+ var result entsql.Result
+ if err := client.Driver().Exec(
+ ctx,
+ `INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+ []any{userID, strings.TrimSpace(providerType), "first_bind"},
+ &result,
+ ); err != nil {
+ return fmt.Errorf("record first bind provider grant: %w", err)
+ }
+
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return fmt.Errorf("read first bind provider grant result: %w", err)
+ }
+ if affected == 0 {
+ return nil
+ }
+
+ if providerDefaults.Balance != 0 {
+ if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil {
+ return fmt.Errorf("apply first bind balance default: %w", err)
+ }
+ }
+ if providerDefaults.Concurrency != 0 {
+ if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil {
+ return fmt.Errorf("apply first bind concurrency default: %w", err)
+ }
+ }
+ if s.defaultSubAssigner != nil {
+ for _, item := range providerDefaults.Subscriptions {
+ if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
+ UserID: userID,
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ Notes: "auto assigned by first bind defaults",
+ }); err != nil {
+ return fmt.Errorf("apply first bind subscription default: %w", err)
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go
new file mode 100644
index 00000000..6e69c121
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service.go
@@ -0,0 +1,543 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "hash/fnv"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "entgo.io/ent/dialect"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+
+ entsql "entgo.io/ent/dialect/sql"
+)
+
+var (
+ ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
+ ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
+ ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
+ ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
+ ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
+ ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
+ ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
+)
+
+const (
+ defaultPendingAuthTTL = 15 * time.Minute
+ defaultPendingAuthCompletionTTL = 5 * time.Minute
+)
+
+type PendingAuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type CreatePendingAuthSessionInput struct {
+ SessionToken string
+ Intent string
+ Identity PendingAuthIdentityKey
+ TargetUserID *int64
+ RedirectTo string
+ ResolvedEmail string
+ RegistrationPasswordHash string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ LocalFlowState map[string]any
+ ExpiresAt time.Time
+}
+
+type IssuePendingAuthCompletionCodeInput struct {
+ PendingAuthSessionID int64
+ BrowserSessionKey string
+ TTL time.Duration
+}
+
+type IssuePendingAuthCompletionCodeResult struct {
+ Code string
+ ExpiresAt time.Time
+}
+
+type PendingIdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type AuthPendingIdentityService struct {
+ entClient *dbent.Client
+}
+
+var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
+
+type authPendingIdentityScopedKeyLockRegistry struct {
+ mu sync.Mutex
+ locks map[string]*authPendingIdentityScopedKeyLockEntry
+}
+
+type authPendingIdentityScopedKeyLockEntry struct {
+ mu sync.Mutex
+ refs int
+}
+
+func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
+ return &authPendingIdentityScopedKeyLockRegistry{
+ locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
+ }
+}
+
+func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
+ normalized := normalizeAuthPendingIdentityLockKeys(keys...)
+ if len(normalized) == 0 {
+ return func() {}
+ }
+
+ entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
+ r.mu.Lock()
+ for _, key := range normalized {
+ entry := r.locks[key]
+ if entry == nil {
+ entry = &authPendingIdentityScopedKeyLockEntry{}
+ r.locks[key] = entry
+ }
+ entry.refs++
+ entries = append(entries, entry)
+ }
+ r.mu.Unlock()
+
+ for _, entry := range entries {
+ entry.mu.Lock()
+ }
+
+ return func() {
+ for i := len(entries) - 1; i >= 0; i-- {
+ entries[i].mu.Unlock()
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for idx, key := range normalized {
+ entry := entries[idx]
+ entry.refs--
+ if entry.refs == 0 {
+ delete(r.locks, key)
+ }
+ }
+ }
+}
+
+func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
+ if len(keys) == 0 {
+ return nil
+ }
+
+ deduped := make(map[string]struct{}, len(keys))
+ for _, key := range keys {
+ trimmed := strings.TrimSpace(key)
+ if trimmed == "" {
+ continue
+ }
+ deduped[trimmed] = struct{}{}
+ }
+ if len(deduped) == 0 {
+ return nil
+ }
+
+ normalized := make([]string, 0, len(deduped))
+ for key := range deduped {
+ normalized = append(normalized, key)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+func authPendingIdentityAdvisoryLockHash(key string) int64 {
+ hasher := fnv.New64a()
+ _, _ = hasher.Write([]byte(key))
+ return int64(hasher.Sum64())
+}
+
+func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
+ release := authPendingIdentityScopedKeyLocks.lock(keys...)
+ normalized := normalizeAuthPendingIdentityLockKeys(keys...)
+ if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
+ return release, nil
+ }
+
+ for _, key := range normalized {
+ var rows entsql.Rows
+ if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
+ release()
+ return nil, err
+ }
+ _ = rows.Close()
+ }
+
+ return release, nil
+}
+
+func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
+ keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
+ if identityID != nil && *identityID > 0 {
+ keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
+ }
+ return keys
+}
+
+func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
+ return &AuthPendingIdentityService{entClient: entClient}
+}
+
+func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken := strings.TrimSpace(input.SessionToken)
+ if sessionToken == "" {
+ var err error
+ sessionToken, err = randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ expiresAt := input.ExpiresAt.UTC()
+ if expiresAt.IsZero() {
+ expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
+ }
+
+ create := s.entClient.PendingAuthSession.Create().
+ SetSessionToken(sessionToken).
+ SetIntent(strings.TrimSpace(input.Intent)).
+ SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
+ SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
+ SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
+ SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
+ SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
+ SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
+ SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
+ SetExpiresAt(expiresAt)
+ if input.TargetUserID != nil {
+ create = create.SetTargetUserID(*input.TargetUserID)
+ }
+ return create.Save(ctx)
+}
+
+func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+
+ code, err := randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ ttl := input.TTL
+ if ttl <= 0 {
+ ttl = defaultPendingAuthCompletionTTL
+ }
+ expiresAt := time.Now().UTC().Add(ttl)
+
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeHash(hashPendingAuthCode(code)).
+ SetCompletionCodeExpiresAt(expiresAt)
+ if strings.TrimSpace(input.BrowserSessionKey) != "" {
+ update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
+ }
+ if _, err := update.Save(ctx); err != nil {
+ return nil, err
+ }
+
+ return &IssuePendingAuthCompletionCodeResult{
+ Code: code,
+ ExpiresAt: expiresAt,
+ }, nil
+}
+
+func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthCodeInvalid
+ }
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
+}
+
+func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+}
+
+func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+ if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
+ return nil, err
+ }
+ return session, nil
+}
+
+func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken = strings.TrimSpace(sessionToken)
+ if sessionToken == "" {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+ return session, nil
+}
+
+func (s *AuthPendingIdentityService) consumeSession(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ browserSessionKey string,
+ expiredErr error,
+ consumedErr error,
+) (*dbent.PendingAuthSession, error) {
+ if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+
+ sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
+ now := time.Now().UTC()
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ Where(
+ pendingauthsession.ConsumedAtIsNil(),
+ pendingauthsession.ExpiresAtGTE(now),
+ pendingauthsession.Or(
+ pendingauthsession.CompletionCodeExpiresAtIsNil(),
+ pendingauthsession.CompletionCodeExpiresAtGTE(now),
+ ),
+ ).
+ SetConsumedAt(now).
+ SetLocalFlowState(sanitizedLocalFlowState).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt()
+ if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
+ update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey))
+ }
+ updated, err := update.Save(ctx)
+ if err == nil {
+ return updated, nil
+ }
+ if !dbent.IsNotFound(err) {
+ return nil, err
+ }
+
+ current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID)
+ if currentErr != nil {
+ if dbent.IsNotFound(currentErr) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, currentErr
+ }
+ if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+ return nil, consumedErr
+}
+
+func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
+ sanitized := copyPendingMap(localFlowState)
+ if len(sanitized) == 0 {
+ return sanitized
+ }
+
+ rawCompletion, ok := sanitized["completion_response"]
+ if !ok {
+ return sanitized
+ }
+ completion, ok := rawCompletion.(map[string]any)
+ if !ok {
+ return sanitized
+ }
+
+ cleanedCompletion := copyPendingMap(completion)
+ for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+ delete(cleanedCompletion, key)
+ }
+ sanitized["completion_response"] = cleanedCompletion
+ return sanitized
+}
+
+func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
+ if session == nil {
+ return ErrPendingAuthSessionNotFound
+ }
+
+ now := time.Now().UTC()
+ if session.ConsumedAt != nil {
+ return consumedErr
+ }
+ if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
+ return expiredErr
+ }
+ if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
+ return expiredErr
+ }
+ if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return ErrPendingAuthBrowserMismatch
+ }
+ return nil
+}
+
+func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
+ return nil, err
+ }
+
+ client := s.entClient
+ txCtx := ctx
+ if err == nil {
+ defer func() { _ = tx.Rollback() }()
+ client = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
+ } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ client = existingTx.Client()
+ }
+
+ releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
+ if err != nil {
+ return nil, err
+ }
+ defer releaseLocks()
+
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
+ dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
+ col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.NEQ(col, input.PendingAuthSessionID),
+ ))
+ }),
+ ).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return nil, err
+ }
+ }
+
+ create := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+
+ decisionID, err := create.
+ OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
+ UpdateNewValues().
+ ID(txCtx)
+ if err != nil {
+ return nil, err
+ }
+
+ decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
+ if err != nil {
+ return nil, err
+ }
+
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return nil, err
+ }
+ }
+
+ return decision, nil
+}
+
+func copyPendingMap(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func randomOpaqueToken(byteLen int) (string, error) {
+ if byteLen <= 0 {
+ byteLen = 16
+ }
+ buf := make([]byte, byteLen)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
+func hashPendingAuthCode(code string) string {
+ sum := sha256.Sum256([]byte(code))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go
new file mode 100644
index 00000000..555bb0e7
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service_test.go
@@ -0,0 +1,526 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return NewAuthPendingIdentityService(client), client
+}
+
+func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("pending-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ TargetUserID: &targetUser.ID,
+ RedirectTo: "/profile",
+ ResolvedEmail: "user@example.com",
+ BrowserSessionKey: "browser-1",
+ UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
+ LocalFlowState: map[string]any{"step": "email_required"},
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, session.SessionToken)
+ require.Equal(t, "bind_current_user", session.Intent)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, targetUser.ID, *session.TargetUserID)
+ require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
+ require.Equal(t, "email_required", session.LocalFlowState["step"])
+}
+
+func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expected",
+ UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
+ LocalFlowState: map[string]any{"step": "pending"},
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expected",
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, issued.Code)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+ require.Empty(t, consumed.CompletionCodeHash)
+ require.Nil(t, consumed.CompletionCodeExpiresAt)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
+}
+
+func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expired",
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expired",
+ TTL: time.Second,
+ })
+ require.NoError(t, err)
+
+ _, err = client.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
+ require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ require.NoError(t, err)
+
+ first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.True(t, first.AdoptDisplayName)
+ require.False(t, first.AdoptAvatar)
+ require.Nil(t, first.IdentityID)
+
+ second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+ require.NotNil(t, second.IdentityID)
+ require.Equal(t, identity.ID, *second.IdentityID)
+ require.True(t, second.AdoptAvatar)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption-reassign@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-reassign").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-reassign",
+ },
+ })
+ require.NoError(t, err)
+
+ firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: firstSession.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, firstDecision.IdentityID)
+ require.Equal(t, identity.ID, *firstDecision.IdentityID)
+
+ secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-reassign",
+ },
+ })
+ require.NoError(t, err)
+
+ secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: secondSession.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, secondDecision.IdentityID)
+ require.Equal(t, identity.ID, *secondDecision.IdentityID)
+
+ reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedFirst.IdentityID)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption-concurrent@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-concurrent").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-concurrent",
+ },
+ })
+ require.NoError(t, err)
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type adoptionResult struct {
+ decision *dbent.IdentityAdoptionDecision
+ err error
+ }
+
+ input := PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ }
+
+ results := make(chan adoptionResult, 2)
+ go func() {
+ decision, err := svc.UpsertAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ decision, err := svc.UpsertAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ require.NoError(t, first.err)
+ require.NoError(t, second.err)
+ require.NotNil(t, first.decision)
+ require.NotNil(t, second.decision)
+ require.Equal(t, first.decision.ID, second.decision.ID)
+
+ count, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+
+ loaded, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, loaded.IdentityID)
+ require.Equal(t, identity.ID, *loaded.IdentityID)
+}
+
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
+ t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
+
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("legacy-null-session@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("legacy-null-session").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.ExecContext(
+ ctx,
+ `INSERT INTO identity_adoption_decisions
+ (identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id)
+ VALUES (?, ?, ?, ?, ?, ?, NULL)`,
+ identity.ID,
+ true,
+ false,
+ time.Now().UTC(),
+ time.Now().UTC(),
+ time.Now().UTC(),
+ )
+ require.NoError(t, err)
+ legacyDecision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.IdentityIDEQ(identity.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, legacyDecision.IdentityID)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "legacy-null-session",
+ },
+ })
+ require.NoError(t, err)
+
+ decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+
+ reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedLegacy.IdentityID)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "subject-session-token",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "token",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "stale-replay-subject",
+ },
+ BrowserSessionKey: "browser-session",
+ })
+ require.NoError(t, err)
+
+ loaded, err := svc.getBrowserSession(ctx, session.SessionToken)
+ require.NoError(t, err)
+
+ consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
+
+func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "legacy-token-subject",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ stored, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+
+ completion, ok := stored.LocalFlowState["completion_response"].(map[string]any)
+ require.True(t, ok)
+ require.NotContains(t, completion, "access_token")
+ require.NotContains(t, completion, "refresh_token")
+ require.NotContains(t, completion, "expires_in")
+ require.NotContains(t, completion, "token_type")
+ require.Equal(t, "/dashboard", completion["redirect"])
+}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index fd28cd42..b1adf071 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
+ "encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -13,6 +14,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -70,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
+ affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
@@ -77,6 +80,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
+type signupGrantPlan struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+}
+
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
@@ -90,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
+ affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
@@ -102,17 +112,25 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
+ affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
-// Register 用户注册,返回token和用户
-func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
- return s.RegisterWithVerification(ctx, email, password, "", "", "")
+func (s *AuthService) EntClient() *dbent.Client {
+ if s == nil {
+ return nil
+ }
+ return s.entClient
}
-// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
-func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
+// Register 用户注册,返回token和用户
+func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
+ return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
+}
+
+// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
+func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -179,12 +197,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 获取默认配置
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
+ grantPlan := s.resolveSignupGrantPlan(ctx, "email")
+
+ // 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
+ var defaultRPMLimit int
if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
+ defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户
@@ -192,8 +210,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ RPMLimit: defaultRPMLimit,
Status: StatusActive,
}
@@ -205,7 +224,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, "email", true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ if s.affiliateService != nil {
+ if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
+ }
+ if code := strings.TrimSpace(affiliateCode); code != "" {
+ if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
+ // 邀请返利码绑定失败不影响注册,只记录日志
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
+ }
+ }
+ }
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -469,12 +500,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 新用户默认值。
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ var defaultRPMLimit int
if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
+ defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
@@ -482,9 +512,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ RPMLimit: defaultRPMLimit,
Status: StatusActive,
+ SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -501,7 +533,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -520,7 +553,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
-
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
@@ -531,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
-func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
+// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
+func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -584,11 +617,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ var defaultRPMLimit int
if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
+ defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
@@ -596,9 +629,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ RPMLimit: defaultRPMLimit,
Status: StatusActive,
+ SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
@@ -630,7 +665,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -646,7 +683,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -670,7 +709,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
-
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
@@ -678,80 +716,289 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
-// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
-const pendingOAuthTokenTTL = 10 * time.Minute
-
-// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
-const pendingOAuthPurpose = "pending_oauth_registration"
-
-type pendingOAuthClaims struct {
- Email string `json:"email"`
- Username string `json:"username"`
- Purpose string `json:"purpose"`
- jwt.RegisteredClaims
-}
-
-// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
-// while waiting for the user to supply an invitation code.
-func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: email,
- Username: username,
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString([]byte(s.cfg.JWT.Secret))
-}
-
-// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
-// Returns ErrInvalidToken when the token is invalid or expired.
-func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
- if len(tokenStr) > maxTokenLength {
- return "", "", ErrInvalidToken
- }
- parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
- token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
- if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
- }
- return []byte(s.cfg.JWT.Secret), nil
- })
- if parseErr != nil {
- return "", "", ErrInvalidToken
- }
- claims, ok := token.Claims.(*pendingOAuthClaims)
- if !ok || !token.Valid {
- return "", "", ErrInvalidToken
- }
- if claims.Purpose != pendingOAuthPurpose {
- return "", "", ErrInvalidToken
- }
- return claims.Email, claims.Username, nil
-}
-
-func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
- items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
+ Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
+func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
+ plan := signupGrantPlan{}
+ if s != nil && s.cfg != nil {
+ plan.Balance = s.cfg.Default.UserBalance
+ plan.Concurrency = s.cfg.Default.UserConcurrency
+ }
+ if s == nil || s.settingService == nil {
+ return plan
+ }
+
+ plan.Balance = s.settingService.GetDefaultBalance(ctx)
+ plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
+ plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
+
+ resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
+ return plan
+ }
+ if !enabled {
+ return plan
+ }
+
+ plan.Balance = resolved.Balance
+ plan.Concurrency = resolved.Concurrency
+ plan.Subscriptions = resolved.Subscriptions
+ return plan
+}
+
+func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
+ if defaults == nil {
+ return ProviderDefaultGrantSettings{}, false
+ }
+
+ switch strings.ToLower(strings.TrimSpace(signupSource)) {
+ case "email":
+ return defaults.Email, true
+ case "linuxdo":
+ return defaults.LinuxDo, true
+ case "oidc":
+ return defaults.OIDC, true
+ case "wechat":
+ return defaults.WeChat, true
+ default:
+ return ProviderDefaultGrantSettings{}, false
+ }
+}
+
+// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
+// for an OAuth-registered user. Failures are logged but never block registration.
+func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
+ if s.affiliateService == nil || userID <= 0 {
+ return
+ }
+ if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
+ }
+ if code := strings.TrimSpace(affiliateCode); code != "" {
+ if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
+ }
+ }
+}
+
+func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
+ if user == nil || user.ID <= 0 {
+ return
+ }
+
+ if strings.TrimSpace(signupSource) == "" {
+ signupSource = "email"
+ }
+ s.updateUserSignupSource(ctx, user.ID, signupSource)
+
+ if touchLogin {
+ s.touchUserLogin(ctx, user.ID)
+ }
+}
+
+func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ if strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetSignupSource(signupSource).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
+ }
+}
+
+func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ now := time.Now().UTC()
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetLastLoginAt(now).
+ SetLastActiveAt(now).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
+ }
+}
+
+func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
+ if s == nil || user == nil || user.ID <= 0 {
+ return
+ }
+ identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
+ if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
+ }
+ }
+}
+
+func (s *AuthService) shouldApplyEmailFirstBindDefaults(
+ ctx context.Context,
+ userID int64,
+ identity *dbent.AuthIdentity,
+ created bool,
+) bool {
+ source := emailAuthIdentitySource(identity.Metadata)
+ if source == "auth_service_login_backfill" {
+ return false
+ }
+ if created {
+ return true
+ }
+ if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
+ return false
+ }
+ if source != "auth_service_dual_write" {
+ return false
+ }
+
+ hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
+ return false
+ }
+ return !hasGrant
+}
+
+func emailAuthIdentitySource(metadata map[string]any) string {
+ if len(metadata) == 0 {
+ return ""
+ }
+ raw, ok := metadata["source"]
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(fmt.Sprint(raw))
+}
+
+func (s *AuthService) hasProviderGrantRecord(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+ grantReason string,
+) (bool, error) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return false, nil
+ }
+
+ rows, err := s.entClient.QueryContext(
+ ctx,
+ `SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`,
+ userID,
+ strings.TrimSpace(providerType),
+ strings.TrimSpace(grantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ defer func() { _ = rows.Close() }()
+ return rows.Next(), rows.Err()
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return nil, false
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return nil, false
+ }
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_dual_write"
+ }
+
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
+ }
+
+ buildQuery := func() *dbent.AuthIdentityQuery {
+ return client.AuthIdentity.Query().Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(email),
+ )
+ }
+
+ existed, err := buildQuery().Exist(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+
+ if !existed {
+ if err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject(email).
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{
+ "source": strings.TrimSpace(source),
+ }).
+ OnConflictColumns(
+ authidentity.FieldProviderType,
+ authidentity.FieldProviderKey,
+ authidentity.FieldProviderSubject,
+ ).
+ DoNothing().
+ Exec(ctx); err != nil {
+ if isSQLNoRowsError(err) {
+ return nil, false
+ }
+ }
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+ }
+
+ identity, err := buildQuery().Only(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+ if identity.UserID != user.ID {
+ logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
+ return nil, false
+ }
+
+ return identity, !existed
+}
+
+func inferLegacySignupSource(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ switch {
+ case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
+ return "linuxdo"
+ case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
+ return "oidc"
+ case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
+ return "wechat"
+ default:
+ return "email"
+ }
+}
+
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -834,7 +1081,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
- strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
+ strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
@@ -853,7 +1101,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
- TokenVersion: user.TokenVersion,
+ TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@@ -919,7 +1167,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
- if claims.TokenVersion != user.TokenVersion {
+ if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
@@ -1147,7 +1395,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
- TokenVersion: user.TokenVersion,
+ TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
@@ -1227,7 +1475,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion(密码更改后所有Token失效)
- if data.TokenVersion != user.TokenVersion {
+ if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配,撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
@@ -1272,8 +1520,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
+// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
+// Access/refresh token verification both depend on TokenVersion, so bumping it provides
+// immediate revocation even if refresh-token cache cleanup later fails.
+func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("get user: %w", err)
+ }
+
+ user.TokenVersion++
+ if err := s.userRepo.Update(ctx, user); err != nil {
+ return fmt.Errorf("update user: %w", err)
+ }
+
+ if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
+ }
+ return nil
+}
+
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
+
+func resolvedTokenVersion(user *User) int64 {
+ if user == nil {
+ return 0
+ }
+ if user.TokenVersionResolved {
+ return user.TokenVersion
+ }
+
+ material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
+ sum := sha256.Sum256([]byte(material))
+ fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
+ return user.TokenVersion ^ fingerprint
+}
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
new file mode 100644
index 00000000..ea2308f7
--- /dev/null
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -0,0 +1,853 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type emailBindDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
+}
+
+type flakyEmailBindDefaultSubAssignerStub struct {
+ err error
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return nil, false, s.err
+}
+
+func newAuthServiceForEmailBind(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
+}
+
+func newAuthServiceForEmailBindWithRefreshCache(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+ refreshTokenCache service.RefreshTokenCache,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingRepo := &emailBindSettingRepoStub{values: settings}
+ settingSvc := service.NewSettingService(settingRepo, cfg)
+
+ var emailSvc *service.EmailService
+ if emailCache != nil {
+ emailSvc = service.NewEmailService(settingRepo, emailCache)
+ }
+
+ svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
+ return svc, repo, client
+}
+
+func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
+ SetUsername("legacy-user").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "newemail@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "newemail@example.com", storedUser.Email)
+ require.Equal(t, 11.0, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("newemail@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, user.ID, assigner.calls[0].UserID)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ sourceUser, err := client.User.Create().
+ SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.User.Create().
+ SetEmail("taken@example.com").
+ SetUsername("taken-user").
+ SetPasswordHash("hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailExists)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, sourceUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
+ assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
+ user, err := client.User.Create().
+ SetEmail(originalEmail).
+ SetUsername("legacy-rollback").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
+ require.ErrorContains(t, err, "apply email first bind defaults")
+ require.ErrorContains(t, err, "temporary assign failure")
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, originalEmail, storedUser.Email)
+ require.Equal(t, "old-hash", storedUser.PasswordHash)
+ require.Equal(t, 2.5, storedUser.Balance)
+ require.Equal(t, 1, storedUser.Concurrency)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("rollback@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("source-user@example.com").
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailReserved)
+ require.Nil(t, updatedUser)
+}
+
+func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(7.5).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "new@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.Equal(t, 7.5, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, newIdentityCount)
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, oldIdentityCount)
+
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password")
+ require.ErrorIs(t, err, service.ErrPasswordIncorrect)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "current@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, oldIdentityCount)
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, newIdentityCount)
+}
+
+func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
+ ctx := context.Background()
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ refreshTokenCache := newEmailBindRefreshTokenCacheStub()
+ userRepo := newEmailBindUserRepoStub(&service.User{
+ ID: 41,
+ Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ PasswordHash: "old-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ })
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ }
+ emailService := service.NewEmailService(nil, cache)
+ svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
+
+ oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
+ ID: 41,
+ Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ }, "")
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+
+ storedUser, err := userRepo.GetByID(ctx, 41)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ _, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
+ require.ErrorIs(t, err, service.ErrTokenRevoked)
+
+ _, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
+ require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
+}
+
+type emailBindSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+type emailBindCacheStub struct {
+ data *service.VerificationCodeData
+ err error
+}
+
+func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.data, nil
+}
+
+func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+type emailBindRefreshTokenCacheStub struct {
+ mu sync.Mutex
+ tokens map[string]*service.RefreshTokenData
+ userSets map[int64]map[string]struct{}
+ families map[string]map[string]struct{}
+}
+
+func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
+ return &emailBindRefreshTokenCacheStub{
+ tokens: make(map[string]*service.RefreshTokenData),
+ userSets: make(map[int64]map[string]struct{}),
+ families: make(map[string]map[string]struct{}),
+ }
+}
+
+func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ cloned := *data
+ s.tokens[tokenHash] = &cloned
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ data, ok := s.tokens[tokenHash]
+ if !ok {
+ return nil, service.ErrRefreshTokenNotFound
+ }
+ cloned := *data
+ return &cloned, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.userSets {
+ delete(tokenSet, tokenHash)
+ }
+ for _, tokenSet := range s.families {
+ delete(tokenSet, tokenHash)
+ }
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for tokenHash := range s.userSets[userID] {
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.families {
+ delete(tokenSet, tokenHash)
+ }
+ }
+ delete(s.userSets, userID)
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for tokenHash := range s.families[familyID] {
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.userSets {
+ delete(tokenSet, tokenHash)
+ }
+ }
+ delete(s.families, familyID)
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.userSets[userID] == nil {
+ s.userSets[userID] = make(map[string]struct{})
+ }
+ s.userSets[userID][tokenHash] = struct{}{}
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.families[familyID] == nil {
+ s.families[familyID] = make(map[string]struct{})
+ }
+ s.families[familyID][tokenHash] = struct{}{}
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ tokenSet := s.userSets[userID]
+ out := make([]string, 0, len(tokenSet))
+ for tokenHash := range tokenSet {
+ out = append(out, tokenHash)
+ }
+ return out, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ tokenSet := s.families[familyID]
+ out := make([]string, 0, len(tokenSet))
+ for tokenHash := range tokenSet {
+ out = append(out, tokenHash)
+ }
+ return out, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ _, ok := s.families[familyID][tokenHash]
+ return ok, nil
+}
+
+type emailBindUserRepoStub struct {
+ mu sync.Mutex
+ usersByID map[int64]*service.User
+ usersByEmail map[string]*service.User
+}
+
+func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
+ cloned := cloneEmailBindUser(user)
+ return &emailBindUserRepoStub{
+ usersByID: map[int64]*service.User{
+ cloned.ID: cloned,
+ },
+ usersByEmail: map[string]*service.User{
+ cloned.Email: cloned,
+ },
+ }
+}
+
+func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
+
+func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ user, ok := s.usersByID[id]
+ if !ok {
+ return nil, service.ErrUserNotFound
+ }
+ return cloneEmailBindUser(user), nil
+}
+
+func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ user, ok := s.usersByEmail[email]
+ if !ok {
+ return nil, service.ErrUserNotFound
+ }
+ return cloneEmailBindUser(user), nil
+}
+
+func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ existing, ok := s.usersByID[user.ID]
+ if !ok {
+ return service.ErrUserNotFound
+ }
+ delete(s.usersByEmail, existing.Email)
+ cloned := cloneEmailBindUser(user)
+ s.usersByID[user.ID] = cloned
+ s.usersByEmail[cloned.Email] = cloned
+ return nil
+}
+
+func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ _, ok := s.usersByEmail[email]
+ return ok, nil
+}
+
+func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func cloneEmailBindUser(user *service.User) *service.User {
+ if user == nil {
+ return nil
+ }
+ cloned := *user
+ return &cloned
+}
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
new file mode 100644
index 00000000..53048b92
--- /dev/null
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -0,0 +1,482 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type authIdentityDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
+type flakyAuthIdentityDefaultSubAssignerStub struct {
+ failuresRemaining int
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ if s.failuresRemaining > 0 {
+ s.failuresRemaining--
+ return nil, false, errors.New("temporary assign failure")
+ }
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
+type authIdentitySettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func newAuthServiceWithEnt(
+ t *testing.T,
+ settings map[string]string,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-auth-identity-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+ settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
+ values: settings,
+ }, cfg)
+
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
+ return svc, repo, client
+}
+
+func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ token, user, err := svc.Register(ctx, "user@example.com", "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, user)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "email", storedUser.SignupSource)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("user@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.NotNil(t, identity.VerifiedAt)
+}
+
+func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("login@example.com").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetBalance(1).
+ SetConcurrency(1).
+ Save(ctx)
+ require.NoError(t, err)
+
+ old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
+ _, err = client.User.UpdateOneID(user.ID).
+ SetLastLoginAt(old).
+ SetLastActiveAt(old).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+ require.True(t, storedUser.LastLoginAt.Equal(old))
+ require.True(t, storedUser.LastActiveAt.Equal(old))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "record@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 1,
+ Concurrency: 1,
+ }
+ require.NoError(t, user.SetPassword("password"))
+ require.NoError(t, repo.Create(ctx, user))
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("record@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("legacy@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("merged-first-bind@example.com").
+ SetUsername("merged-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("bound@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(2).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("bound@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "preexisting"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2.0, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) {
+ assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("retry-first-bind@example.com").
+ SetUsername("retry-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var count int
+ rows, err := client.QueryContext(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ userID,
+ providerType,
+ grantReason,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+ require.True(t, rows.Next())
+ require.NoError(t, rows.Scan(&count))
+ require.NoError(t, rows.Err())
+ return count
+}
diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go
deleted file mode 100644
index 0472e06c..00000000
--- a/backend/internal/service/auth_service_pending_oauth_test.go
+++ /dev/null
@@ -1,146 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/require"
-)
-
-func newAuthServiceForPendingOAuthTest() *AuthService {
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret-pending-oauth",
- ExpireHour: 1,
- },
- }
- return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
-}
-
-// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
-func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
- require.NotEmpty(t, token)
-
- email, username, err := svc.VerifyPendingOAuthToken(token)
- require.NoError(t, err)
- require.Equal(t, "user@example.com", email)
- require.Equal(t, "alice", username)
-}
-
-// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- // 签发一个普通 access token(JWTClaims,无 Purpose 字段)
- accessToken, err := svc.GenerateToken(&User{
- ID: 1,
- Email: "user@example.com",
- Role: RoleUser,
- })
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(accessToken)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "some_other_purpose",
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "", // 旧 token 无此字段,反序列化后为零值
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- past := time.Now().Add(-1 * time.Hour)
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(past),
- IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
- other := NewAuthService(nil, nil, nil, nil, &config.Config{
- JWT: config.JWTConfig{Secret: "other-secret"},
- }, nil, nil, nil, nil, nil, nil)
-
- token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
-
- svc := newAuthServiceForPendingOAuthTest()
- _, _, err = svc.VerifyPendingOAuthToken(token)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
- giant := make([]byte, maxTokenLength+1)
- for i := range giant {
- giant[i] = 'a'
- }
- _, _, err := svc.VerifyPendingOAuthToken(string(giant))
- require.ErrorIs(t, err, ErrInvalidToken)
-}
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 103bafe7..acc44a38 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+ if s.err != nil {
+ return nil, s.err
+ }
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ result[key] = v
+ }
+ }
+ return result, nil
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error
}
+type refreshTokenCacheStub struct{}
+
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
+ return nil, ErrRefreshTokenNotFound
+}
+
+func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -161,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil, // promoService
nil, // defaultSubAssigner
+ nil, // affiliateService
)
}
@@ -192,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
- _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
+ _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -204,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
- _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
+ _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -218,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
- _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
+ _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
@@ -322,7 +374,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
@@ -469,8 +522,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
@@ -484,3 +538,132 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
+
+func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
+ repo := &userRepoStub{nextID: 52}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 12.5, user.Balance)
+ require.Equal(t, 7, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 53}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "99",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "88",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-global@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 3.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 54}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 9.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
+ repo := &userRepoStub{nextID: 61}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Equal(t, int64(61), user.ID)
+ require.Equal(t, 21.75, user.Balance)
+ require.Equal(t, 9, user.Concurrency)
+ require.Len(t, repo.created, 1)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(22), assigner.calls[0].GroupID)
+ require.Equal(t, 14, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
+ existing := &User{
+ ID: 88,
+ Email: "linuxdo-123@linuxdo-connect.invalid",
+ Username: "existing-linuxdo",
+ Role: RoleUser,
+ Status: StatusActive,
+ Balance: 4,
+ Concurrency: 1,
+ TokenVersion: 2,
+ }
+ repo := &userRepoStub{user: existing}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.Equal(t, existing.ID, user.ID)
+ require.Equal(t, 4.0, user.Balance)
+ require.Equal(t, 1, user.Concurrency)
+ require.Empty(t, repo.created)
+ require.Empty(t, assigner.calls)
+}
diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go
index 477ba1b2..3512822f 100644
--- a/backend/internal/service/auth_service_turnstile_register_test.go
+++ b/backend/internal/service/auth_service_turnstile_register_test.go
@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
+ nil, // affiliateService
)
}
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index f2ad0a3d..050db55b 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -20,6 +20,9 @@ import (
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
+ // RPM 超限错误。gateway_handler 负责映射为 HTTP 429。
+ ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded")
+ ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
@@ -87,6 +90,8 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader
+ userRPMCache UserRPMCache
+ userGroupRateRepo UserGroupRateRepository
cfg *config.Config
circuitBreaker *billingCircuitBreaker
@@ -104,12 +109,22 @@ type BillingCacheService struct {
}
// NewBillingCacheService 创建计费缓存服务
-func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
+func NewBillingCacheService(
+ cache BillingCache,
+ userRepo UserRepository,
+ subRepo UserSubscriptionRepository,
+ apiKeyRepo APIKeyRepository,
+ userRPMCache UserRPMCache,
+ userGroupRateRepo UserGroupRateRepository,
+ cfg *config.Config,
+) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo,
+ userRPMCache: userRPMCache,
+ userGroupRateRepo: userGroupRateRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
@@ -493,6 +508,18 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
return nil
}
+// InvalidateAPIKeyRateLimit invalidates the Redis rate-limit usage cache for an API key.
+func (s *BillingCacheService) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
+ if s.cache == nil {
+ return nil
+ }
+ if err := s.cache.InvalidateAPIKeyRateLimit(ctx, keyID); err != nil {
+ logger.LegacyPrintf("service.billing_cache", "Warning: invalidate api key rate limit cache failed for key %d: %v", keyID, err)
+ return err
+ }
+ return nil
+}
+
// ============================================
// API Key 限速缓存方法
// ============================================
@@ -664,6 +691,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user
}
}
+ // RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。
+ if err := s.checkRPM(ctx, user, group); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝:
+//
+// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。
+// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。
+// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。
+// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。
+//
+// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。
+// Redis 故障一律 fail-open(打 warning,不阻塞业务)。
+func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error {
+ if s == nil || s.userRPMCache == nil || user == nil {
+ return nil
+ }
+
+ // ── 第一层:分组级检查(override 或 group.rpm_limit) ──
+ if group != nil {
+ // 解析 override:优先从 auth cache snapshot,nil 时回退 DB。
+ var override *int
+ if user.UserGroupRPMOverride != nil {
+ override = user.UserGroupRPMOverride
+ } else if s.userGroupRateRepo != nil {
+ dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID)
+ if err != nil {
+ logger.LegacyPrintf(
+ "service.billing_cache",
+ "Warning: rpm override lookup failed for user=%d group=%d: %v",
+ user.ID, group.ID, err,
+ )
+ } else {
+ override = dbOverride
+ }
+ }
+
+ if override != nil {
+ // override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。
+ if *override > 0 {
+ count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
+ if incErr != nil {
+ logger.LegacyPrintf(
+ "service.billing_cache",
+ "Warning: rpm increment (override) failed for user=%d group=%d: %v",
+ user.ID, group.ID, incErr,
+ )
+ // fail-open
+ } else if count > *override {
+ return ErrGroupRPMExceeded
+ }
+ }
+ // override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。
+ } else if group.RPMLimit > 0 {
+ // 无 override,检查 group.rpm_limit。
+ count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID)
+ if err != nil {
+ logger.LegacyPrintf(
+ "service.billing_cache",
+ "Warning: rpm increment (group) failed for user=%d group=%d: %v",
+ user.ID, group.ID, err,
+ )
+ // fail-open
+ } else if count > group.RPMLimit {
+ return ErrGroupRPMExceeded
+ }
+ }
+ }
+
+ // ── 第二层:用户级全局硬上限(始终生效) ──
+ if user.RPMLimit > 0 {
+ count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID)
+ if err != nil {
+ logger.LegacyPrintf(
+ "service.billing_cache",
+ "Warning: rpm increment (user) failed for user=%d: %v",
+ user.ID, err,
+ )
+ return nil // fail-open
+ }
+ if count > user.RPMLimit {
+ return ErrUserRPMExceeded
+ }
+ }
+
return nil
}
diff --git a/backend/internal/service/billing_cache_service_rpm_test.go b/backend/internal/service/billing_cache_service_rpm_test.go
new file mode 100644
index 00000000..de66136f
--- /dev/null
+++ b/backend/internal/service/billing_cache_service_rpm_test.go
@@ -0,0 +1,253 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。
+type userRPMCacheStub struct {
+ userGroupCalls int32
+ userCalls int32
+
+ userGroupCounts []int // 依次返回的计数值
+ userGroupErr error
+ userCounts []int
+ userErr error
+}
+
+func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
+ idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1
+ if s.userGroupErr != nil {
+ return 0, s.userGroupErr
+ }
+ if idx < len(s.userGroupCounts) {
+ return s.userGroupCounts[idx], nil
+ }
+ return 1, nil
+}
+
+func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) {
+ idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1
+ if s.userErr != nil {
+ return 0, s.userErr
+ }
+ if idx < len(s.userCounts) {
+ return s.userCounts[idx], nil
+ }
+ return 1, nil
+}
+
+func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) {
+ return 0, nil
+}
+
+func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) {
+ return 0, nil
+}
+
+// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。
+type rpmOverrideRepoStub struct {
+ UserGroupRateRepository
+
+ override *int
+ err error
+ calls int32
+}
+
+func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
+ atomic.AddInt32(&s.calls, 1)
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.override, nil
+}
+
+func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService {
+ t.Helper()
+ // 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。
+ // 我们只直接测 checkRPM。
+ svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{})
+ t.Cleanup(svc.Stop)
+ return svc
+}
+
+func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) {
+ override := 2
+ // user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰)
+ cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}}
+ repo := &rpmOverrideRepoStub{override: &override}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试
+ group := &Group{ID: 10, RPMLimit: 100}
+
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded)
+
+ require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数")
+ // 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user
+ require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用")
+ require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls))
+}
+
+func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) {
+ override := 100 // override 很高
+ // user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3
+ cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
+ repo := &rpmOverrideRepoStub{override: &override}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100
+ group := &Group{ID: 10, RPMLimit: 100}
+
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override")
+}
+
+func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) {
+ zero := 0
+ // user 计数: 依次返回 1..6
+ cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}}
+ repo := &rpmOverrideRepoStub{override: &zero}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 5}
+ group := &Group{ID: 10, RPMLimit: 100}
+
+ // override=0 跳过分组计数,但 user.RPMLimit=5 仍生效
+ for i := 0; i < 5; i++ {
+ require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1)
+ }
+ require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded,
+ "override=0 跳过分组但 user 全局上限仍应生效")
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器")
+ require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用")
+}
+
+func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) {
+ zero := 0
+ cache := &userRPMCacheStub{}
+ repo := &rpmOverrideRepoStub{override: &zero}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 0} // user 也不限
+ group := &Group{ID: 10, RPMLimit: 100}
+
+ for i := 0; i < 50; i++ {
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ }
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数")
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数")
+}
+
+func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) {
+ // user-group 计数: 5, 6;user 计数: 默认 1(不干扰)
+ cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}}
+ repo := &rpmOverrideRepoStub{override: nil}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超
+ group := &Group{ID: 10, RPMLimit: 5}
+
+ require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超
+ require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5
+
+ require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls))
+ // 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user
+ require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回")
+}
+
+func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) {
+ cache := &userRPMCacheStub{userGroupCounts: []int{3}}
+ repo := &rpmOverrideRepoStub{err: errors.New("db down")}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 0}
+ group := &Group{ID: 10, RPMLimit: 10}
+
+ // override 查询失败后应继续尝试 group 分支(不直接拒绝)
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
+ require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls))
+}
+
+func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) {
+ cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
+ repo := &rpmOverrideRepoStub{override: nil}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 2}
+ group := &Group{ID: 10, RPMLimit: 0} // 分组未设限
+
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded)
+
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键")
+ require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
+}
+
+func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) {
+ cache := &userRPMCacheStub{}
+ repo := &rpmOverrideRepoStub{override: nil}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 0}
+ group := &Group{ID: 10, RPMLimit: 0}
+
+ for i := 0; i < 10; i++ {
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ }
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
+}
+
+func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) {
+ cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")}
+ repo := &rpmOverrideRepoStub{override: nil}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 0}
+ group := &Group{ID: 10, RPMLimit: 5}
+
+ // Redis 故障时应 fail-open,不拒绝请求
+ require.NoError(t, svc.checkRPM(context.Background(), user, group))
+ require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls))
+}
+
+func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) {
+ cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}}
+ repo := &rpmOverrideRepoStub{}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ user := &User{ID: 1, RPMLimit: 2}
+
+ // 无 group(纯用户级限流场景),不应查询 rpm_override。
+ require.NoError(t, svc.checkRPM(context.Background(), user, nil))
+ require.NoError(t, svc.checkRPM(context.Background(), user, nil))
+ require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded)
+
+ require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override")
+ require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls))
+}
+
+func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) {
+ cache := &userRPMCacheStub{}
+ repo := &rpmOverrideRepoStub{}
+ svc := newBillingServiceForRPM(t, cache, repo)
+
+ require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10}))
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls))
+ require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls))
+ require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls))
+}
diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go
index 4a8b8f03..962becf0 100644
--- a/backend/internal/service/billing_cache_service_singleflight_test.go
+++ b/backend/internal/service/billing_cache_service_singleflight_test.go
@@ -86,13 +86,21 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User,
return &User{ID: id, Balance: s.balance}, nil
}
+func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *balanceLoadUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return nil
+}
+
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
delay: 80 * time.Millisecond,
balance: 12.34,
}
- svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
+ svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16
diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go
index 7d7045e2..849e24b8 100644
--- a/backend/internal/service/billing_cache_service_test.go
+++ b/backend/internal/service/billing_cache_service_test.go
@@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context,
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
- svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
+ svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
@@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
- svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
+ svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 32a54cbe..392b3e0b 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() {
SupportsCacheBreakdown: false,
}
- // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
- s.fallbackPrices["gpt-5.1"] = &ModelPricing{
- InputPricePerToken: 1.25e-6, // $1.25 per MTok
- InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
- OutputPricePerToken: 10e-6, // $10 per MTok
- OutputPricePerTokenPriority: 20e-6, // $20 per MTok
- CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
- CacheReadPricePerToken: 0.125e-6,
- CacheReadPricePerTokenPriority: 0.25e-6,
- SupportsCacheBreakdown: false,
- }
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
@@ -228,18 +217,15 @@ func (s *BillingService) initFallbackPricing() {
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
}
+ // GPT-5.5 暂无独立定价,回退到 GPT-5.4
+ s.fallbackPrices["gpt-5.5"] = s.fallbackPrices["gpt-5.4"]
+
s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{
InputPricePerToken: 7.5e-7,
OutputPricePerToken: 4.5e-6,
CacheReadPricePerToken: 7.5e-8,
SupportsCacheBreakdown: false,
}
- s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
- InputPricePerToken: 2e-7,
- OutputPricePerToken: 1.25e-6,
- CacheReadPricePerToken: 2e-8,
- SupportsCacheBreakdown: false,
- }
// OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
@@ -251,8 +237,8 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
- // Codex 族兜底统一按 GPT-5.1 Codex 价格计费
- s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
+ // Codex 族兜底统一按 GPT-5.3 Codex 价格计费
+ s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
InputPricePerTokenPriority: 3e-6, // $3 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
@@ -262,17 +248,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.3e-6,
SupportsCacheBreakdown: false,
}
- s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
- InputPricePerToken: 1.75e-6,
- InputPricePerTokenPriority: 3.5e-6,
- OutputPricePerToken: 14e-6,
- OutputPricePerTokenPriority: 28e-6,
- CacheCreationPricePerToken: 1.75e-6,
- CacheReadPricePerToken: 0.175e-6,
- CacheReadPricePerTokenPriority: 0.35e-6,
- SupportsCacheBreakdown: false,
- }
- s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
// getFallbackPricing 根据模型系列获取回退价格
@@ -316,22 +291,16 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower)
switch normalized {
+ case "gpt-5.5":
+ return s.fallbackPrices["gpt-5.5"]
case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"]
- case "gpt-5.4-nano":
- return s.fallbackPrices["gpt-5.4-nano"]
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2":
return s.fallbackPrices["gpt-5.2"]
- case "gpt-5.2-codex":
- return s.fallbackPrices["gpt-5.2-codex"]
- case "gpt-5.3-codex":
+ case "gpt-5.3-codex", "gpt-5.3-codex-spark":
return s.fallbackPrices["gpt-5.3-codex"]
- case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
- return s.fallbackPrices["gpt-5.1-codex"]
- case "gpt-5.1":
- return s.fallbackPrices["gpt-5.1"]
}
}
@@ -448,8 +417,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
})
}
- if input.RateMultiplier <= 0 {
- input.RateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
+ if input.RateMultiplier < 0 {
+ input.RateMultiplier = 0
}
var breakdown *CostBreakdown
@@ -493,8 +463,9 @@ func (s *BillingService) computeTokenBreakdown(
rateMultiplier float64, serviceTier string,
applyLongCtx bool,
) *CostBreakdown {
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
inputPrice := pricing.InputPricePerToken
@@ -665,8 +636,14 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
}
func isOpenAIGPT54Model(model string) bool {
- normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
- return normalized == "gpt-5.4"
+ trimmed := strings.TrimSpace(strings.ToLower(model))
+ // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
+ // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
+ if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
+ return false
+ }
+ normalized := normalizeCodexModel(trimmed)
+ return normalized == "gpt-5.4" || normalized == "gpt-5.5"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
@@ -831,9 +808,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
// 计算总费用
totalCost := unitPrice * float64(imageCount)
- // 应用倍率
- if rateMultiplier <= 0 {
- rateMultiplier = 1.0
+ // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
+ if rateMultiplier < 0 {
+ rateMultiplier = 0
}
actualCost := totalCost * rateMultiplier
diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go
index fa90f6bb..8d3ca987 100644
--- a/backend/internal/service/billing_service_image_test.go
+++ b/backend/internal/service/billing_service_image_test.go
@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
require.Equal(t, 0.0, cost.ActualCost)
}
-// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
+// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
+// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go
new file mode 100644
index 00000000..83788196
--- /dev/null
+++ b/backend/internal/service/billing_service_rate_multiplier_test.go
@@ -0,0 +1,63 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
+// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
+func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64 // ActualCost / TotalCost
+ }{
+ {"negative clamped to 0", -1.5, 0},
+ {"zero passes through as 0 (defense in depth)", 0, 0},
+ {"positive 2x applied", 2.0, 2.0},
+ {"positive 0.5x applied", 0.5, 0.5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
+ require.NoError(t, err)
+ require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
+
+// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
+// 同样遵循"负数 → 0"语义。
+func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
+ svc := newTestBillingService()
+ price := 0.04
+ cfg := &ImagePriceConfig{Price1K: &price}
+
+ tests := []struct {
+ name string
+ multiplier float64
+ wantRatio float64
+ }{
+ {"negative clamped to 0", -0.5, 0},
+ {"zero passes through", 0, 0},
+ {"positive 3x applied", 3.0, 3.0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
+ require.NotNil(t, cost)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
+ })
+ }
+}
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index 2cf134e2..222abd69 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
}
-func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
-}
-
-func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
- svc := newTestBillingService()
-
- tokens := UsageTokens{InputTokens: 1000}
-
- costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
- require.NoError(t, err)
-
- costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
-}
-
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService()
@@ -151,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
require.Contains(t, err.Error(), "pricing not found")
}
-func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
- svc := newTestBillingService()
-
- pricing, err := svc.GetModelPricing("gpt-5.1")
- require.NoError(t, err)
- require.NotNil(t, pricing)
- require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
-}
-
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService()
@@ -186,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
require.Zero(t, pricing.LongContextInputThreshold)
}
-func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) {
- svc := newTestBillingService()
-
- pricing, err := svc.GetModelPricing("gpt-5.4-nano")
- require.NoError(t, err)
- require.NotNil(t, pricing)
- require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12)
- require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12)
- require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12)
- require.Zero(t, pricing.LongContextInputThreshold)
-}
-
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService()
@@ -232,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
- {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7},
- {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
- {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
- {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
+ {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6},
+ {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6},
+ {name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6},
+ {name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
}
diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go
index 694c3384..e6a92d1a 100644
--- a/backend/internal/service/billing_service_unified_test.go
+++ b/backend/internal/service/billing_service_unified_test.go
@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
-func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
+// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
+func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
- costZero, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
- RateMultiplier: 0, // should default to 1.0
+ RateMultiplier: 0,
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
-func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
+// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
+// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
+func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
- costNeg, err := bs.CalculateCostUnified(CostInput{
+ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
Resolver: resolver,
})
require.NoError(t, err)
-
- costOne, err := bs.CalculateCostUnified(CostInput{
- Ctx: context.Background(),
- Model: "claude-sonnet-4",
- Tokens: tokens,
- RateMultiplier: 1.0,
- Resolver: resolver,
- })
- require.NoError(t, err)
-
- require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
+ require.Greater(t, cost.TotalCost, 0.0)
+ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index 93beb972..158bf8a3 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -111,6 +111,18 @@ func (c *Channel) IsActive() bool {
return c.Status == StatusActive
}
+// normalizeBillingModelSource 若 BillingModelSource 为空则回填默认值 ChannelMapped。
+// 作为 *Channel 的实体方法集中管理默认值,service 层只需在 Channel 进入内存
+// (缓存装填、repo 读出)时调用一次,下游读路径就无需重复兜底。
+func (c *Channel) normalizeBillingModelSource() {
+ if c == nil {
+ return
+ }
+ if c.BillingModelSource == "" {
+ c.BillingModelSource = BillingModelSourceChannelMapped
+ }
+}
+
// GetModelPricing 根据模型名查找渠道定价,未找到返回 nil。
// 精确匹配,大小写不敏感。返回值拷贝,不污染缓存。
func (c *Channel) GetModelPricing(model string) *ChannelModelPricing {
@@ -345,3 +357,209 @@ type ChannelUsageFields struct {
BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
+
+// SupportedModel 渠道的一个支持模型条目(无通配符、可直接展示给用户)
+type SupportedModel struct {
+ Name string // 用户侧模型名
+ Platform string // 所属平台
+ Pricing *ChannelModelPricing // 定价详情(nil 表示未配置定价)
+}
+
+// wildcardSuffix 是模型模式中的通配符后缀标记(仅支持尾部匹配)。
+const wildcardSuffix = "*"
+
+// splitWildcardSuffix 将模型模式拆分为 (prefix, isWildcard)。
+//
+// "claude-opus-*" → ("claude-opus-", true)
+// "claude-opus-4" → ("claude-opus-4", false)
+// "*" → ("", true)
+//
+// 注意:返回的 prefix 保持原始大小写,由调用方按需 ToLower。
+func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) {
+ if strings.HasSuffix(pattern, wildcardSuffix) {
+ return strings.TrimSuffix(pattern, wildcardSuffix), true
+ }
+ return pattern, false
+}
+
+// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。
+// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。
+func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing {
+ if c == nil {
+ return nil
+ }
+ modelLower := strings.ToLower(model)
+ for i := range c.ModelPricing {
+ if c.ModelPricing[i].Platform != platform {
+ continue
+ }
+ for _, m := range c.ModelPricing[i].Models {
+ if strings.ToLower(m) == modelLower {
+ cp := c.ModelPricing[i].Clone()
+ return &cp
+ }
+ }
+ }
+ return nil
+}
+
+// platformPricingIndex 是单个平台下定价信息的复合索引。
+// 一次扫描即可同时支持精确查找(exact 分支)与有序遍历(wildcard 分支),
+// 避免 SupportedModels 对每个平台重复扫描定价列表。
+//
+// byLower 与 names/originalCase 共享同一套去重规则:以 lower-case 模型名为 key,
+// 首个命中保留其原始大小写。names 维持按定价行扫描顺序的稳定迭代。
+type platformPricingIndex struct {
+ byLower map[string]*ChannelModelPricing // lowercased model name → pricing (Clone'd)
+ originalCase map[string]string // lowercased model name → original-case model name
+ names []string // priced model names in their ORIGINAL case, insertion-ordered, deduped case-insensitively (first wins)
+}
+
+// buildPricingIndex 对渠道的定价列表做一次扫描,按 platform 聚合为查找索引。
+// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。
+// 通配符后缀条目(如 "claude-*")不被索引(它们是模式,不是具体模型名)。
+// 同一平台中以大小写不敏感方式去重,先出现者保留原始大小写。
+func buildPricingIndex(pricings []ChannelModelPricing) map[string]*platformPricingIndex {
+ idx := make(map[string]*platformPricingIndex)
+ for i := range pricings {
+ p := pricings[i]
+ pidx, ok := idx[p.Platform]
+ if !ok {
+ pidx = &platformPricingIndex{
+ byLower: make(map[string]*ChannelModelPricing),
+ originalCase: make(map[string]string),
+ names: make([]string, 0),
+ }
+ idx[p.Platform] = pidx
+ }
+ for _, m := range p.Models {
+ if _, wild := splitWildcardSuffix(m); wild {
+ continue
+ }
+ lower := strings.ToLower(m)
+ if _, exists := pidx.byLower[lower]; exists {
+ continue // 首个命中胜出(case-insensitive 去重后第一个定价 / 第一个原始大小写)
+ }
+ cp := pricings[i].Clone()
+ pidx.byLower[lower] = &cp
+ pidx.originalCase[lower] = m
+ pidx.names = append(pidx.names, m)
+ }
+ }
+ return idx
+}
+
+// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。
+//
+// 算法(mapping ∪ pricing 并联):
+//
+// - Pass A(mapping):遍历 ModelMapping
+// - 精确 src → target:显示名 = src(用户视角),定价用 target 在同 platform 定价里查
+// (mapping 改写后实际计费的是 target;这是用户感知的"实际花费")。
+// target 为空或为通配符时退化为按 src 自查。
+// - 通配符 src(如 "claude-3-*"):用同 platform 定价里前缀匹配的模型作为候选展开,
+// 每个候选用自身定价(通配符场景一般是 passthrough,target 通常也是通配符)。
+// - "*" 单独 mapping key 走通配符分支(前缀为空 → 全展开)。
+// - Pass B(pricing-only):遍历 ModelPricing 中所有非通配符模型,对未在 Pass A 添加过的
+// 补齐——显示名 = 定价模型名,定价 = 自身(这是关键修复:定价存在即代表渠道支持该模型,
+// 即使没配映射)。
+//
+// 显示名命中定价时使用**定价的原始大小写**(定价是模型身份的事实来源)。
+// 按 (Platform, Name) 稳定排序,按 (Platform, lowercase(Name)) 去重,先到者胜出。
+//
+// 注意:定价仅在 channel.ModelPricing 内查找——全局 LiteLLM 回落由调用方
+// (`ChannelService.ListAvailable`)在合成展示数据时叠加。
+func (c *Channel) SupportedModels() []SupportedModel {
+ if c == nil {
+ return nil
+ }
+ if len(c.ModelMapping) == 0 && len(c.ModelPricing) == 0 {
+ return nil
+ }
+
+ idx := buildPricingIndex(c.ModelPricing)
+
+ type dedupKey struct {
+ platform string
+ name string
+ }
+ seen := make(map[dedupKey]struct{})
+ result := make([]SupportedModel, 0)
+
+ // lookup 在 platform pricing index 中按精确名查定价,命中时返回定价大小写。
+ lookup := func(pidx *platformPricingIndex, name string) (display string, pricing *ChannelModelPricing) {
+ if pidx == nil || name == "" {
+ return name, nil
+ }
+ lower := strings.ToLower(name)
+ if p, ok := pidx.byLower[lower]; ok {
+ return pidx.originalCase[lower], p
+ }
+ return name, nil
+ }
+
+ add := func(platform, displayName string, pricing *ChannelModelPricing) {
+ key := dedupKey{platform: platform, name: strings.ToLower(displayName)}
+ if _, ok := seen[key]; ok {
+ return
+ }
+ seen[key] = struct{}{}
+ result = append(result, SupportedModel{
+ Name: displayName,
+ Platform: platform,
+ Pricing: pricing,
+ })
+ }
+
+ // Pass A:从 mapping 展开
+ for platform, mapping := range c.ModelMapping {
+ if len(mapping) == 0 {
+ continue
+ }
+ pidx := idx[platform]
+ for src, target := range mapping {
+ prefix, isWild := splitWildcardSuffix(src)
+ if isWild {
+ if pidx == nil {
+ continue
+ }
+ prefixLower := strings.ToLower(prefix)
+ for _, candidate := range pidx.names {
+ if strings.HasPrefix(strings.ToLower(candidate), prefixLower) {
+ display, pricing := lookup(pidx, candidate)
+ add(platform, display, pricing)
+ }
+ }
+ continue
+ }
+ // 精确 mapping:定价按 target 查;target 缺失/通配则退化按 src 查
+ pricingKey := target
+ if pricingKey == "" {
+ pricingKey = src
+ }
+ if _, targetWild := splitWildcardSuffix(pricingKey); targetWild {
+ pricingKey = src
+ }
+ _, pricing := lookup(pidx, pricingKey)
+ // 显示名优先用 src 在定价里的原始大小写(若 src 本身是个定价模型名)
+ displayName, _ := lookup(pidx, src)
+ add(platform, displayName, pricing)
+ }
+ }
+
+ // Pass B:从 pricing 补齐 mapping 未覆盖的具体模型(修复"定价存在但没配映射 → 不显示")
+ for platform, pidx := range idx {
+ for _, name := range pidx.names {
+ display, pricing := lookup(pidx, name)
+ add(platform, display, pricing)
+ }
+ }
+
+ sort.SliceStable(result, func(i, j int) bool {
+ if result[i].Platform != result[j].Platform {
+ return result[i].Platform < result[j].Platform
+ }
+ return result[i].Name < result[j].Name
+ })
+ return result
+}
diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go
new file mode 100644
index 00000000..815730e3
--- /dev/null
+++ b/backend/internal/service/channel_available.go
@@ -0,0 +1,149 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "strings"
+)
+
+// AvailableGroupRef 渠道视图中关联分组的简要信息。
+//
+// 用户侧「可用渠道」页面据此展示:专属分组 vs 公开分组(IsExclusive)、
+// 订阅 vs 标准(SubscriptionType)、默认倍率(RateMultiplier)。用户专属倍率
+// 不在这里暴露,前端自己通过 /groups/rates 拉取,和 API 密钥页面保持一致。
+type AvailableGroupRef struct {
+ ID int64
+ Name string
+ Platform string
+ SubscriptionType string
+ RateMultiplier float64
+ IsExclusive bool
+}
+
+// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 +
+// 关联的分组 + 推导出的支持模型列表(无通配符)。
+type AvailableChannel struct {
+ ID int64
+ Name string
+ Description string
+ Status string
+ BillingModelSource string
+ RestrictModels bool
+ Groups []AvailableGroupRef
+ SupportedModels []SupportedModel
+}
+
+// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。
+//
+// 支持模型通过 (*Channel).SupportedModels() 计算(mapping ∪ pricing 并联)。
+// 对于渠道未配置定价的模型,进一步用 PricingService 的全局 LiteLLM 数据合成
+// 一份展示用定价,让用户看到默认价格而非"未配置"。
+//
+// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中
+// 的分组(已停用或删除)会被忽略。
+//
+// 前置条件:s.groupRepo 必须非 nil(由 wire DI 保证)。直接 nil-deref 用于 fail-fast,
+// 避免静默掩盖注入缺失。
+func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) {
+ channels, err := s.repo.ListAll(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list channels: %w", err)
+ }
+
+ groups, err := s.groupRepo.ListActive(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list active groups: %w", err)
+ }
+ groupByID := make(map[int64]AvailableGroupRef, len(groups))
+ for i := range groups {
+ g := groups[i]
+ groupByID[g.ID] = AvailableGroupRef{
+ ID: g.ID,
+ Name: g.Name,
+ Platform: g.Platform,
+ SubscriptionType: g.SubscriptionType,
+ RateMultiplier: g.RateMultiplier,
+ IsExclusive: g.IsExclusive,
+ }
+ }
+
+ out := make([]AvailableChannel, 0, len(channels))
+ for i := range channels {
+ ch := &channels[i]
+ groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs))
+ for _, gid := range ch.GroupIDs {
+ if ref, ok := groupByID[gid]; ok {
+ groups = append(groups, ref)
+ }
+ }
+ sort.SliceStable(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name })
+
+ ch.normalizeBillingModelSource()
+
+ supported := ch.SupportedModels()
+ s.fillGlobalPricingFallback(supported)
+
+ out = append(out, AvailableChannel{
+ ID: ch.ID,
+ Name: ch.Name,
+ Description: ch.Description,
+ Status: ch.Status,
+ BillingModelSource: ch.BillingModelSource,
+ RestrictModels: ch.RestrictModels,
+ Groups: groups,
+ SupportedModels: supported,
+ })
+ }
+
+ sort.SliceStable(out, func(i, j int) bool {
+ return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name)
+ })
+ return out, nil
+}
+
+// fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份
+// 展示用定价(按 token 计费)。仅用于「可用渠道」展示,不影响真实计费链路。
+//
+// 当 s.pricingService 为 nil(测试场景),跳过回落。
+func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) {
+ if s.pricingService == nil {
+ return
+ }
+ for i := range models {
+ if models[i].Pricing != nil {
+ continue
+ }
+ lp := s.pricingService.GetModelPricing(models[i].Name)
+ if lp == nil {
+ continue
+ }
+ models[i].Pricing = synthesizePricingFromLiteLLM(lp)
+ }
+}
+
+// synthesizePricingFromLiteLLM 把 LiteLLM 的定价数据转成 ChannelModelPricing 形态,
+// 仅用于展示。BillingMode 固定为 token;图片场景的 OutputCostPerImageToken 也归到
+// ImageOutputPrice 字段(与渠道侧"图片输出按 token 计价"语义一致)。
+//
+// LiteLLM 中字段 0 视为未配置,不带入展示。
+func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing) *ChannelModelPricing {
+ if lp == nil {
+ return nil
+ }
+ return &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: nonZeroPtr(lp.InputCostPerToken),
+ OutputPrice: nonZeroPtr(lp.OutputCostPerToken),
+ CacheWritePrice: nonZeroPtr(lp.CacheCreationInputTokenCost),
+ CacheReadPrice: nonZeroPtr(lp.CacheReadInputTokenCost),
+ ImageOutputPrice: nonZeroPtr(lp.OutputCostPerImageToken),
+ }
+}
+
+func nonZeroPtr(v float64) *float64 {
+ if v == 0 {
+ return nil
+ }
+ return &v
+}
diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go
new file mode 100644
index 00000000..8be70ceb
--- /dev/null
+++ b/backend/internal/service/channel_available_test.go
@@ -0,0 +1,177 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub,
+// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。
+// listActiveErr 非 nil 时,ListActive 返回该错误用于错误传播测试。
+// listActiveCalls 记录调用次数,用于断言「失败短路时不再访问 groupRepo」等行为。
+type stubGroupRepoForAvailable struct {
+ activeGroups []Group
+ listActiveErr error
+ listActiveCalls int
+}
+
+func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) {
+ s.listActiveCalls++
+ if s.listActiveErr != nil {
+ return nil, s.listActiveErr
+ }
+ return s.activeGroups, nil
+}
+
+func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil }
+func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil }
+func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil }
+func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) {
+ return false, nil
+}
+func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) {
+ return 0, 0, nil
+}
+func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
+ return 0, nil
+}
+func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
+ return nil, nil
+}
+func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
+ return nil
+}
+func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
+ return nil
+}
+
+// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels,
+// groupRepo 由参数决定。传入空 stub 表示「活跃分组列表为空」。
+func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService {
+ repo := &mockChannelRepository{
+ listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil },
+ }
+ return NewChannelService(repo, groupRepo, nil, nil)
+}
+
+func TestListAvailable_EmptyActiveGroups_NoGroupsAttached(t *testing.T) {
+ // 活跃分组列表为空时,渠道的 Groups 应为空切片,不报错。
+ channels := []Channel{{
+ ID: 1,
+ Name: "chA",
+ Status: StatusActive,
+ GroupIDs: []int64{10, 20},
+ }}
+ svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 1)
+ require.Empty(t, out[0].Groups)
+}
+
+func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) {
+ // 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。
+ channels := []Channel{{
+ ID: 1,
+ Name: "chA",
+ Status: StatusActive,
+ GroupIDs: []int64{1, 99},
+ }}
+ groupRepo := &stubGroupRepoForAvailable{
+ activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}},
+ }
+ svc := newAvailableChannelService(channels, groupRepo)
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 1)
+ require.Len(t, out[0].Groups, 1)
+ require.Equal(t, int64(1), out[0].Groups[0].ID)
+}
+
+func TestListAvailable_SortedByName(t *testing.T) {
+ channels := []Channel{
+ {ID: 1, Name: "beta"},
+ {ID: 2, Name: "Alpha"},
+ {ID: 3, Name: "charlie"},
+ }
+ svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 3)
+ require.Equal(t, "Alpha", out[0].Name)
+ require.Equal(t, "beta", out[1].Name)
+ require.Equal(t, "charlie", out[2].Name)
+}
+
+func TestListAvailable_ListAllErrorPropagates(t *testing.T) {
+ // ListAll 返回错误时 ListAvailable 应直接返回包装后的错误,且不再访问 groupRepo(短路)。
+ sentinel := errors.New("list-all-boom")
+ repo := &mockChannelRepository{
+ listAllFn: func(ctx context.Context) ([]Channel, error) { return nil, sentinel },
+ }
+ groupRepo := &stubGroupRepoForAvailable{}
+ svc := NewChannelService(repo, groupRepo, nil, nil)
+ out, err := svc.ListAvailable(context.Background())
+ require.Nil(t, out)
+ require.ErrorIs(t, err, sentinel)
+ require.Contains(t, err.Error(), "list channels", "wrap 前缀缺失,可能 %w 被改为 %v")
+ require.Equal(t, 0, groupRepo.listActiveCalls, "ListAll 失败后不应再调用 groupRepo.ListActive")
+}
+
+func TestListAvailable_ListActiveErrorPropagates(t *testing.T) {
+ // groupRepo.ListActive 返回错误时 ListAvailable 应直接返回包装后的错误。
+ sentinel := errors.New("list-active-boom")
+ svc := newAvailableChannelService(
+ []Channel{{ID: 1, Name: "chA"}},
+ &stubGroupRepoForAvailable{listActiveErr: sentinel},
+ )
+ out, err := svc.ListAvailable(context.Background())
+ require.Nil(t, out)
+ require.ErrorIs(t, err, sentinel)
+ require.Contains(t, err.Error(), "list active groups", "wrap 前缀缺失,可能 %w 被改为 %v")
+}
+
+func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
+ // 渠道 BillingModelSource 为空时应回填为 BillingModelSourceChannelMapped,
+ // 显式值应原样保留(由 service 层统一处理,避免各 handler 重复默认逻辑)。
+ channels := []Channel{
+ {ID: 1, Name: "empty", BillingModelSource: ""},
+ {ID: 2, Name: "explicit", BillingModelSource: BillingModelSourceUpstream},
+ }
+ svc := newAvailableChannelService(channels, &stubGroupRepoForAvailable{})
+ out, err := svc.ListAvailable(context.Background())
+ require.NoError(t, err)
+ require.Len(t, out, 2)
+
+ // 按 Name 查找,避免依赖排序副作用。
+ byName := make(map[string]string, len(out))
+ for _, ch := range out {
+ byName[ch.Name] = ch.BillingModelSource
+ }
+ require.Equal(t, BillingModelSourceChannelMapped, byName["empty"])
+ require.Equal(t, BillingModelSourceUpstream, byName["explicit"])
+}
diff --git a/backend/internal/service/channel_monitor_aggregator.go b/backend/internal/service/channel_monitor_aggregator.go
new file mode 100644
index 00000000..09020f5f
--- /dev/null
+++ b/backend/internal/service/channel_monitor_aggregator.go
@@ -0,0 +1,292 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+)
+
+// 渠道监控聚合层:把 latest + availability 拼成 admin/user 视图所需的 summary / detail。
+// 所有方法都遵守"失败仅日志,返回零值"的原则,避免 N+1 查询失败拖垮列表渲染。
+
+// BatchMonitorStatusSummary 批量聚合多个监控的 latest + 7d 可用率(admin/user list 用,消除 N+1)。
+// 失败时返回空 map,错误仅日志,不影响列表渲染。
+//
+// 参数:
+// - ids: 要聚合的 monitor ID 列表
+// - primaryByID: monitor ID -> primary model(用于读 7d 可用率与 latest 状态)
+// - extrasByID: monitor ID -> extra models 列表(用于读 latest 状态填充 ExtraModels)
+func (s *ChannelMonitorService) BatchMonitorStatusSummary(
+ ctx context.Context,
+ ids []int64,
+ primaryByID map[int64]string,
+ extrasByID map[int64][]string,
+) map[int64]MonitorStatusSummary {
+ out := make(map[int64]MonitorStatusSummary, len(ids))
+ if len(ids) == 0 {
+ return out
+ }
+ latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids)
+ if err != nil {
+ slog.Warn("channel_monitor: batch load latest failed", "error", err)
+ latestMap = map[int64][]*ChannelMonitorLatest{}
+ }
+ availMap, err := s.repo.ComputeAvailabilityForMonitors(ctx, ids, monitorAvailability7Days)
+ if err != nil {
+ slog.Warn("channel_monitor: batch compute availability failed", "error", err)
+ availMap = map[int64][]*ChannelMonitorAvailability{}
+ }
+
+ for _, id := range ids {
+ out[id] = buildStatusSummary(
+ indexLatestByModel(latestMap[id]),
+ indexAvailabilityByModel(availMap[id]),
+ primaryByID[id],
+ extrasByID[id],
+ )
+ }
+ return out
+}
+
+// ListUserView 用户只读视图:列出所有 enabled 监控的概览。
+// 使用批量聚合接口避免 N+1:
+//
+// 1 次查 monitors;
+// 1 次批量 latest(含 ping_latency_ms);
+// 1 次批量 7d availability;
+// 1 次批量 timeline(主模型最近 N 条)。
+func (s *ChannelMonitorService) ListUserView(ctx context.Context) ([]*UserMonitorView, error) {
+ monitors, err := s.repo.ListEnabled(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list enabled monitors: %w", err)
+ }
+ if len(monitors) == 0 {
+ return []*UserMonitorView{}, nil
+ }
+
+ ids, primaryByID, extrasByID := collectMonitorIndexes(monitors)
+ summaries := s.BatchMonitorStatusSummary(ctx, ids, primaryByID, extrasByID)
+ latestMap := s.batchLatest(ctx, ids)
+ timelineMap := s.batchTimeline(ctx, ids, primaryByID)
+
+ views := make([]*UserMonitorView, 0, len(monitors))
+ for _, m := range monitors {
+ primaryLatest := pickLatest(latestMap[m.ID], m.PrimaryModel)
+ views = append(views, buildUserViewFromSummary(m, summaries[m.ID], primaryLatest, timelineMap[m.ID]))
+ }
+ return views, nil
+}
+
+// collectMonitorIndexes 把 monitors 列表按 ID 展开为聚合查询所需的三个索引结构。
+func collectMonitorIndexes(monitors []*ChannelMonitor) ([]int64, map[int64]string, map[int64][]string) {
+ ids := make([]int64, 0, len(monitors))
+ primaryByID := make(map[int64]string, len(monitors))
+ extrasByID := make(map[int64][]string, len(monitors))
+ for _, m := range monitors {
+ ids = append(ids, m.ID)
+ primaryByID[m.ID] = m.PrimaryModel
+ extrasByID[m.ID] = m.ExtraModels
+ }
+ return ids, primaryByID, extrasByID
+}
+
+// batchLatest 批量取 latest per model,失败仅日志(与现有 BatchMonitorStatusSummary 一致,不阻断列表渲染)。
+func (s *ChannelMonitorService) batchLatest(ctx context.Context, ids []int64) map[int64][]*ChannelMonitorLatest {
+ latestMap, err := s.repo.ListLatestForMonitorIDs(ctx, ids)
+ if err != nil {
+ slog.Warn("channel_monitor: user view batch latest failed", "error", err)
+ return map[int64][]*ChannelMonitorLatest{}
+ }
+ return latestMap
+}
+
+// batchTimeline 批量取每个 monitor 主模型最近 monitorTimelineMaxPoints 条历史。
+func (s *ChannelMonitorService) batchTimeline(
+ ctx context.Context,
+ ids []int64,
+ primaryByID map[int64]string,
+) map[int64][]*ChannelMonitorHistoryEntry {
+ timelineMap, err := s.repo.ListRecentHistoryForMonitors(ctx, ids, primaryByID, monitorTimelineMaxPoints)
+ if err != nil {
+ slog.Warn("channel_monitor: user view batch timeline failed", "error", err)
+ return map[int64][]*ChannelMonitorHistoryEntry{}
+ }
+ return timelineMap
+}
+
+// pickLatest 从 latest 切片中挑出指定 model 对应项,未命中返回 nil。
+func pickLatest(rows []*ChannelMonitorLatest, model string) *ChannelMonitorLatest {
+ if model == "" {
+ return nil
+ }
+ for _, r := range rows {
+ if r.Model == model {
+ return r
+ }
+ }
+ return nil
+}
+
+// GetUserDetail 用户只读视图:单个监控详情(每个模型 7d/15d/30d 可用率与平均延迟)。
+// 不暴露 api_key。
+func (s *ChannelMonitorService) GetUserDetail(ctx context.Context, id int64) (*UserMonitorDetail, error) {
+ m, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ if !m.Enabled {
+ return nil, ErrChannelMonitorNotFound
+ }
+
+ latest, err := s.repo.ListLatestPerModel(ctx, id)
+ if err != nil {
+ return nil, fmt.Errorf("list latest per model: %w", err)
+ }
+ availMap, err := s.collectAvailabilityWindows(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ models := mergeModelDetails(m, latest, availMap)
+ return &UserMonitorDetail{
+ ID: m.ID,
+ Name: m.Name,
+ Provider: m.Provider,
+ GroupName: m.GroupName,
+ Models: models,
+ }, nil
+}
+
+// collectAvailabilityWindows 一次性查询 7/15/30 天三个窗口,按模型组织。
+func (s *ChannelMonitorService) collectAvailabilityWindows(ctx context.Context, monitorID int64) (map[int]map[string]*ChannelMonitorAvailability, error) {
+ out := make(map[int]map[string]*ChannelMonitorAvailability, 3)
+ windows := []int{monitorAvailability7Days, monitorAvailability15Days, monitorAvailability30Days}
+ for _, w := range windows {
+ rows, err := s.repo.ComputeAvailability(ctx, monitorID, w)
+ if err != nil {
+ return nil, fmt.Errorf("compute availability %dd: %w", w, err)
+ }
+ out[w] = indexAvailabilityByModel(rows)
+ }
+ return out, nil
+}
+
+// ---------- 纯函数 helper(无 IO,可在 batch / 单 monitor / detail 路径复用)----------
+
+// indexLatestByModel 把 latest 切片按 model 索引(小工具,避免在 hot path 重复写)。
+func indexLatestByModel(rows []*ChannelMonitorLatest) map[string]*ChannelMonitorLatest {
+ m := make(map[string]*ChannelMonitorLatest, len(rows))
+ for _, r := range rows {
+ m[r.Model] = r
+ }
+ return m
+}
+
+// indexAvailabilityByModel 把 availability 切片按 model 索引。
+func indexAvailabilityByModel(rows []*ChannelMonitorAvailability) map[string]*ChannelMonitorAvailability {
+ m := make(map[string]*ChannelMonitorAvailability, len(rows))
+ for _, r := range rows {
+ m[r.Model] = r
+ }
+ return m
+}
+
+// buildStatusSummary 由 latest + availability 字典构造 MonitorStatusSummary。
+// 不做任何 IO,纯组装,便于在 batch 与单 monitor 路径复用。
+func buildStatusSummary(
+ latestByModel map[string]*ChannelMonitorLatest,
+ availByModel map[string]*ChannelMonitorAvailability,
+ primary string,
+ extras []string,
+) MonitorStatusSummary {
+ summary := MonitorStatusSummary{ExtraModels: make([]ExtraModelStatus, 0, len(extras))}
+ if primary != "" {
+ if l, ok := latestByModel[primary]; ok {
+ summary.PrimaryStatus = l.Status
+ summary.PrimaryLatencyMs = l.LatencyMs
+ }
+ if a, ok := availByModel[primary]; ok {
+ summary.Availability7d = a.AvailabilityPct
+ }
+ }
+ for _, model := range extras {
+ entry := ExtraModelStatus{Model: model}
+ if l, ok := latestByModel[model]; ok {
+ entry.Status = l.Status
+ entry.LatencyMs = l.LatencyMs
+ }
+ summary.ExtraModels = append(summary.ExtraModels, entry)
+ }
+ return summary
+}
+
+// buildUserViewFromSummary 用预聚合好的 MonitorStatusSummary + 主模型 latest + timeline 装填 UserMonitorView(无 IO)。
+// primaryLatest 可能为 nil(该监控尚无历史);timelineEntries 可能为空。
+func buildUserViewFromSummary(
+ m *ChannelMonitor,
+ summary MonitorStatusSummary,
+ primaryLatest *ChannelMonitorLatest,
+ timelineEntries []*ChannelMonitorHistoryEntry,
+) *UserMonitorView {
+ view := &UserMonitorView{
+ ID: m.ID,
+ Name: m.Name,
+ Provider: m.Provider,
+ GroupName: m.GroupName,
+ PrimaryModel: m.PrimaryModel,
+ PrimaryStatus: summary.PrimaryStatus,
+ PrimaryLatencyMs: summary.PrimaryLatencyMs,
+ Availability7d: summary.Availability7d,
+ ExtraModels: summary.ExtraModels,
+ Timeline: buildTimelinePoints(timelineEntries),
+ }
+ if primaryLatest != nil {
+ view.PrimaryPingLatencyMs = primaryLatest.PingLatencyMs
+ }
+ return view
+}
+
+// buildTimelinePoints 把 history entry 裁剪为 timeline 点(去除 message/ID/Model,减小响应体)。
+func buildTimelinePoints(entries []*ChannelMonitorHistoryEntry) []UserMonitorTimelinePoint {
+ out := make([]UserMonitorTimelinePoint, 0, len(entries))
+ for _, e := range entries {
+ out = append(out, UserMonitorTimelinePoint{
+ Status: e.Status,
+ LatencyMs: e.LatencyMs,
+ PingLatencyMs: e.PingLatencyMs,
+ CheckedAt: e.CheckedAt,
+ })
+ }
+ return out
+}
+
+// mergeModelDetails 合并 latest + availability 三个窗口为 ModelDetail 列表。
+// 复用 indexLatestByModel,避免在多处重复写 build map 逻辑。
+func mergeModelDetails(
+ m *ChannelMonitor,
+ latest []*ChannelMonitorLatest,
+ availMap map[int]map[string]*ChannelMonitorAvailability,
+) []ModelDetail {
+ all := append([]string{m.PrimaryModel}, m.ExtraModels...)
+ latestByModel := indexLatestByModel(latest)
+ out := make([]ModelDetail, 0, len(all))
+ for _, model := range all {
+ d := ModelDetail{Model: model}
+ if l, ok := latestByModel[model]; ok {
+ d.LatestStatus = l.Status
+ d.LatestLatencyMs = l.LatencyMs
+ }
+ if a, ok := availMap[monitorAvailability7Days][model]; ok {
+ d.Availability7d = a.AvailabilityPct
+ d.AvgLatency7dMs = a.AvgLatencyMs
+ }
+ if a, ok := availMap[monitorAvailability15Days][model]; ok {
+ d.Availability15d = a.AvailabilityPct
+ }
+ if a, ok := availMap[monitorAvailability30Days][model]; ok {
+ d.Availability30d = a.AvailabilityPct
+ }
+ out = append(out, d)
+ }
+ return out
+}
diff --git a/backend/internal/service/channel_monitor_challenge.go b/backend/internal/service/channel_monitor_challenge.go
new file mode 100644
index 00000000..e81a9e2a
--- /dev/null
+++ b/backend/internal/service/channel_monitor_challenge.go
@@ -0,0 +1,80 @@
+package service
+
+import (
+ "fmt"
+ "math/rand/v2"
+ "regexp"
+ "strconv"
+)
+
+// monitorChallengePromptTemplate 1:1 复刻 BingZi-233/check-cx 的 few-shot 模板。
+const monitorChallengePromptTemplate = `Calculate and respond with ONLY the number, nothing else.
+
+Q: 3 + 5 = ?
+A: 8
+
+Q: 12 - 7 = ?
+A: 5
+
+Q: %d %s %d = ?
+A:`
+
+// monitorChallengeNumberRegex 提取响应中的所有整数(含负号)。
+var monitorChallengeNumberRegex = regexp.MustCompile(`-?\d+`)
+
+// monitorChallenge 一次 challenge 的 prompt + 期望答案。
+type monitorChallenge struct {
+ Prompt string
+ Expected string
+}
+
+// generateChallenge 生成一次随机算术 challenge:
+// - 随机两个 [monitorChallengeMin, monitorChallengeMax] 整数
+// - 50% 加 / 50% 减;减法用 max - min 保证非负
+// - 渲染 few-shot 模板
+//
+// 不强求加密随机:math/rand/v2 足够分散,避免 crypto/rand 的开销。
+func generateChallenge() monitorChallenge {
+ a := randIntInRange(monitorChallengeMin, monitorChallengeMax)
+ b := randIntInRange(monitorChallengeMin, monitorChallengeMax)
+
+ if rand.IntN(2) == 0 { //nolint:gosec // 仅用于生成测试问题,无安全影响
+ // 加法
+ return monitorChallenge{
+ Prompt: fmt.Sprintf(monitorChallengePromptTemplate, a, "+", b),
+ Expected: strconv.Itoa(a + b),
+ }
+ }
+
+ // 减法,保证非负
+ hi, lo := a, b
+ if lo > hi {
+ hi, lo = lo, hi
+ }
+ return monitorChallenge{
+ Prompt: fmt.Sprintf(monitorChallengePromptTemplate, hi, "-", lo),
+ Expected: strconv.Itoa(hi - lo),
+ }
+}
+
+// randIntInRange 返回 [min, max] 闭区间的随机整数。
+func randIntInRange(minVal, maxVal int) int {
+ if maxVal <= minVal {
+ return minVal
+ }
+ return minVal + rand.IntN(maxVal-minVal+1) //nolint:gosec
+}
+
+// validateChallenge 在响应文本中查找 expected 整数答案,返回是否通过校验。
+func validateChallenge(responseText, expected string) bool {
+ if responseText == "" || expected == "" {
+ return false
+ }
+ matches := monitorChallengeNumberRegex.FindAllString(responseText, -1)
+ for _, m := range matches {
+ if m == expected {
+ return true
+ }
+ }
+ return false
+}
diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go
new file mode 100644
index 00000000..33570629
--- /dev/null
+++ b/backend/internal/service/channel_monitor_checker.go
@@ -0,0 +1,443 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strings"
+ "time"
+
+ "github.com/tidwall/gjson"
+)
+
+// monitorHTTPClient 共享一个 http.Client,避免每次检测重建 transport。
+// 自定义 Transport 在 dial 时强制再次校验 IP,防止 DNS rebinding 绕过 validateEndpoint。
+var monitorHTTPClient = newSSRFSafeHTTPClient(monitorRequestTimeout)
+
+// monitorPingHTTPClient 用于 endpoint origin 的 HEAD ping,超时更短。
+var monitorPingHTTPClient = newSSRFSafeHTTPClient(monitorPingTimeout)
+
+// newSSRFSafeHTTPClient 返回一个使用 safeDialContext 的 http.Client。
+// 仅供监控模块对外发起请求使用——所有目标都应是公网 endpoint。
+func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
+ tr := &http.Transport{
+ DialContext: safeDialContext,
+ ForceAttemptHTTP2: true,
+ MaxIdleConns: 16,
+ IdleConnTimeout: monitorIdleConnTimeout,
+ TLSHandshakeTimeout: monitorTLSHandshakeTimeout,
+ ResponseHeaderTimeout: monitorResponseHeaderTimeout,
+ }
+ return &http.Client{Timeout: timeout, Transport: tr}
+}
+
+// CheckOptions 承载一次检测的自定义入参。
+// 所有字段都是可选(零值即等价于"用默认行为")。
+type CheckOptions struct {
+ // ExtraHeaders 用户自定义 HTTP 头(merge 到 adapter 默认 headers,用户优先)。
+ ExtraHeaders map[string]string
+ // BodyOverrideMode: off | merge | replace
+ BodyOverrideMode string
+ // BodyOverride 在 merge 模式下做浅合并(key 命中黑名单时静默丢弃),
+ // 在 replace 模式下直接当作完整 body。
+ BodyOverride map[string]any
+}
+
+// runCheckForModel 对单个 (provider, model) 做一次完整检测。
+// 不返回 error:所有失败都包装进 CheckResult.Status=error/failed。
+//
+// opts 承载模板 / 监控快照带来的自定义配置。nil 等同于 "off + 无 extra headers"。
+func runCheckForModel(ctx context.Context, provider, endpoint, apiKey, model string, opts *CheckOptions) *CheckResult {
+ res := &CheckResult{
+ Model: model,
+ Status: MonitorStatusError,
+ CheckedAt: time.Now(),
+ }
+
+ challenge := generateChallenge()
+ mode := bodyOverrideMode(opts)
+
+ start := time.Now()
+ respText, rawBody, statusCode, err := callProvider(ctx, provider, endpoint, apiKey, model, challenge.Prompt, opts)
+ latency := time.Since(start)
+ latencyMs := int(latency / time.Millisecond)
+ res.LatencyMs = &latencyMs
+
+ if err != nil {
+ res.Status = MonitorStatusError
+ res.Message = truncateMessage(sanitizeErrorMessage(err.Error()))
+ return res
+ }
+ if statusCode < 200 || statusCode >= 300 {
+ // 错误路径:用 rawBody 而非 respText(gjson textPath 抽取在错误响应里通常为空,
+ // 会丢掉真正的上游错误信息,例如 `{"error":{"message":"No available accounts ..."}}`)。
+ res.Status = MonitorStatusError
+ bodySnippet := truncateForErrorBody(rawBody)
+ res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("upstream HTTP %d: %s", statusCode, bodySnippet)))
+ return res
+ }
+
+ // Replace 模式:跳过 challenge 校验(用户 body 是静态的,challenge 没法嵌入)。
+ // 改用「HTTP 2xx + 响应文本(adapter.textPath 抽取)非空」作为 operational 判定。
+ // 响应文本为空则降级为 failed(视为上游回了 200 但没实际内容)。
+ if mode == MonitorBodyOverrideModeReplace {
+ if strings.TrimSpace(respText) == "" {
+ res.Status = MonitorStatusFailed
+ res.Message = truncateMessage("replace-mode: upstream returned 2xx with empty text")
+ return res
+ }
+ return finalizeOperationalOrDegraded(res, latency, latencyMs)
+ }
+
+ if !validateChallenge(respText, challenge.Expected) {
+ res.Status = MonitorStatusFailed
+ res.Message = truncateMessage(sanitizeErrorMessage(fmt.Sprintf("challenge mismatch (expected %s, got %q)", challenge.Expected, respText)))
+ return res
+ }
+
+ return finalizeOperationalOrDegraded(res, latency, latencyMs)
+}
+
+// finalizeOperationalOrDegraded 负责走到最后一步的 operational/degraded 判定。
+// 拆出来是为了让 runCheckForModel 不超过 30 行。
+func finalizeOperationalOrDegraded(res *CheckResult, latency time.Duration, latencyMs int) *CheckResult {
+ if latency >= monitorDegradedThreshold {
+ res.Status = MonitorStatusDegraded
+ res.Message = truncateMessage(fmt.Sprintf("slow response: %dms", latencyMs))
+ return res
+ }
+ res.Status = MonitorStatusOperational
+ return res
+}
+
+// bodyOverrideMode 归一取 opts.BodyOverrideMode,nil opts / 空串都视为 off。
+func bodyOverrideMode(opts *CheckOptions) string {
+ if opts == nil || opts.BodyOverrideMode == "" {
+ return MonitorBodyOverrideModeOff
+ }
+ return opts.BodyOverrideMode
+}
+
+// pingEndpointOrigin 对 endpoint 的 origin (scheme://host) 发起 HEAD 请求,返回耗时。
+// 失败时返回 nil(不影响主状态判定)。
+func pingEndpointOrigin(ctx context.Context, endpoint string) *int {
+ origin, err := extractOrigin(endpoint)
+ if err != nil || origin == "" {
+ return nil
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodHead, origin, nil)
+ if err != nil {
+ return nil
+ }
+ start := time.Now()
+ resp, err := monitorPingHTTPClient.Do(req)
+ if err != nil {
+ return nil
+ }
+ defer func() { _ = resp.Body.Close() }()
+ _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, monitorPingDiscardMaxBytes))
+ ms := int(time.Since(start) / time.Millisecond)
+ return &ms
+}
+
+// providerAdapter 描述某个 provider 在 challenge 检测中需要的 4 件事:
+// - 拼出请求路径(含 model 占位)
+// - 序列化请求体
+// - 构造鉴权头
+// - 从响应 JSON 中按 path 提取文本(gjson path)
+//
+// 加新 provider 只需要在 providerAdapters 里增加一个条目,无需触碰 callProvider / validateProvider。
+type providerAdapter struct {
+ buildPath func(model string) string
+ buildBody func(model, prompt string) ([]byte, error)
+ buildHeaders func(apiKey string) map[string]string
+ textPath string // gjson 提取响应文本的 path
+}
+
+// providerAdapters 全部已支持的 provider。键值即 MonitorProvider* 字符串。
+//
+//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
+var providerAdapters = map[string]providerAdapter{
+ MonitorProviderOpenAI: {
+ buildPath: func(string) string { return providerOpenAIPath },
+ buildBody: func(model, prompt string) ([]byte, error) {
+ return json.Marshal(map[string]any{
+ "model": model,
+ "messages": []map[string]string{{"role": "user", "content": prompt}},
+ "max_tokens": monitorChallengeMaxTokens,
+ "stream": false,
+ })
+ },
+ buildHeaders: func(apiKey string) map[string]string {
+ return map[string]string{"Authorization": "Bearer " + apiKey}
+ },
+ textPath: "choices.0.message.content",
+ },
+ MonitorProviderAnthropic: {
+ buildPath: func(string) string { return providerAnthropicPath },
+ buildBody: func(model, prompt string) ([]byte, error) {
+ return json.Marshal(map[string]any{
+ "model": model,
+ "messages": []map[string]string{{"role": "user", "content": prompt}},
+ "max_tokens": monitorChallengeMaxTokens,
+ })
+ },
+ buildHeaders: func(apiKey string) map[string]string {
+ return map[string]string{
+ "x-api-key": apiKey,
+ "anthropic-version": monitorAnthropicAPIVersion,
+ }
+ },
+ textPath: "content.0.text",
+ },
+ MonitorProviderGemini: {
+ // Gemini 把 model 名写在 URL path 上:/v1beta/models/{model}:generateContent
+ buildPath: func(model string) string { return fmt.Sprintf(providerGeminiPathTemplate, model) },
+ buildBody: func(_, prompt string) ([]byte, error) {
+ return json.Marshal(map[string]any{
+ "contents": []map[string]any{
+ {"parts": []map[string]any{{"text": prompt}}},
+ },
+ "generationConfig": map[string]any{"maxOutputTokens": monitorChallengeMaxTokens},
+ })
+ },
+ // 使用 x-goog-api-key header 而不是 ?key= query,避免 *url.Error 把 key 回填到错误日志。
+ buildHeaders: func(apiKey string) map[string]string {
+ return map[string]string{"x-goog-api-key": apiKey}
+ },
+ textPath: "candidates.0.content.parts.0.text",
+ },
+}
+
+// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。
+// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。
+func isSupportedProvider(p string) bool {
+ _, ok := providerAdapters[p]
+ return ok
+}
+
+// callProvider 通过 providerAdapters 分发到具体实现。
+// opts 承载用户的自定义 headers / body 覆盖(可为 nil)。
+//
+// 返回值:
+// - extractedText: 按 textPath 抽出的成功文本,仅在 status 2xx 时有意义;非 2xx 时通常为空串
+// - rawBody: 完整响应体的字符串形式(已被 monitorResponseMaxBytes 截断),用于错误路径保留上游真实回包
+// - status: HTTP 状态码
+// - err: 网络 / 序列化错误
+func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) {
+ adapter, ok := providerAdapters[provider]
+ if !ok {
+ return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
+ }
+ body, err := buildRequestBody(adapter, provider, model, prompt, opts)
+ if err != nil {
+ return "", "", 0, err
+ }
+ headers := mergeHeaders(adapter.buildHeaders(apiKey), opts)
+ full := joinURL(endpoint, adapter.buildPath(model))
+ respBytes, status, err := postRawJSON(ctx, full, body, headers)
+ if err != nil {
+ return "", "", status, err
+ }
+ return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil
+}
+
+// mergeHeaders 把用户自定义 headers 合并到 adapter 默认 headers 上。
+// 用户值覆盖默认;命中黑名单(hop-by-hop / 由 http.Client 自管的)的 key 静默丢弃。
+func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string {
+ if opts == nil || len(opts.ExtraHeaders) == 0 {
+ return base
+ }
+ out := make(map[string]string, len(base)+len(opts.ExtraHeaders))
+ for k, v := range base {
+ out[k] = v
+ }
+ for k, v := range opts.ExtraHeaders {
+ if IsForbiddenHeaderName(k) {
+ continue
+ }
+ out[k] = v
+ }
+ return out
+}
+
+// buildRequestBody 根据 body_override_mode 构造请求 body。
+//
+// - off: adapter 默认 body
+// - merge: adapter 默认 body 与 BodyOverride 浅合并;BodyOverride 中命中
+// bodyMergeKeyDenyList[provider] 的 key 会被静默丢弃,避免破坏 challenge / model 路由
+// - replace: 直接 marshal BodyOverride 作为完整 body
+//
+// 任何 mode 返回的 []byte 都已经是合法 JSON,可直接送入 postRawJSON。
+func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) {
+ mode := bodyOverrideMode(opts)
+
+ if mode == MonitorBodyOverrideModeReplace {
+ if opts == nil || len(opts.BodyOverride) == 0 {
+ return nil, fmt.Errorf("replace mode: body_override is empty")
+ }
+ body, err := json.Marshal(opts.BodyOverride)
+ if err != nil {
+ return nil, fmt.Errorf("marshal body_override (replace): %w", err)
+ }
+ return body, nil
+ }
+
+ defaultBody, err := adapter.buildBody(model, prompt)
+ if err != nil {
+ return nil, fmt.Errorf("marshal default body: %w", err)
+ }
+ if mode != MonitorBodyOverrideModeMerge || opts == nil || len(opts.BodyOverride) == 0 {
+ return defaultBody, nil
+ }
+
+ var defaultMap map[string]any
+ if err := json.Unmarshal(defaultBody, &defaultMap); err != nil {
+ return nil, fmt.Errorf("unmarshal default body for merge: %w", err)
+ }
+ deny := bodyMergeKeyDenyList[provider]
+ for k, v := range opts.BodyOverride {
+ if deny[k] {
+ continue
+ }
+ defaultMap[k] = v
+ }
+ merged, err := json.Marshal(defaultMap)
+ if err != nil {
+ return nil, fmt.Errorf("marshal merged body: %w", err)
+ }
+ return merged, nil
+}
+
+// bodyMergeKeyDenyList 在 merge 模式下,禁止用户覆盖这些 provider-specific 的关键字段。
+// 思路抄 check-cx 的 EXCLUDED_METADATA_KEYS:保护 challenge / model 路由不被用户误伤。
+// 用户想动这些字段就用 replace 模式(已知会跳 challenge 校验)。
+//
+//nolint:gochecknoglobals // 静态查表,初始化后不变。
+var bodyMergeKeyDenyList = map[string]map[string]bool{
+ MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true},
+ MonitorProviderAnthropic: {"model": true, "messages": true},
+ MonitorProviderGemini: {"contents": true},
+}
+
+// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
+// adapter 自行 marshal 是为了精确控制字段顺序与类型,所以这里直接收 []byte 而不是 any。
+func postRawJSON(ctx context.Context, fullURL string, payload []byte, headers map[string]string) ([]byte, int, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, 0, fmt.Errorf("build request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+
+ resp, err := monitorHTTPClient.Do(req)
+ if err != nil {
+ return nil, 0, fmt.Errorf("do request: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ respBody, err := io.ReadAll(io.LimitReader(resp.Body, monitorResponseMaxBytes))
+ if err != nil {
+ return nil, resp.StatusCode, fmt.Errorf("read body: %w", err)
+ }
+ return respBody, resp.StatusCode, nil
+}
+
+// joinURL 把 base origin 与 path 拼成完整 URL。
+// 容忍 base 末尾有/无斜杠,path 必带前导斜杠。
+func joinURL(base, path string) string {
+ base = strings.TrimRight(base, "/")
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ return base + path
+}
+
+// extractOrigin 从一个 endpoint URL 中提取 scheme://host[:port] 部分。
+func extractOrigin(endpoint string) (string, error) {
+ u, err := url.Parse(endpoint)
+ if err != nil {
+ return "", err
+ }
+ if u.Scheme == "" || u.Host == "" {
+ return "", errors.New("endpoint missing scheme or host")
+ }
+ return u.Scheme + "://" + u.Host, nil
+}
+
+// monitorSensitiveQueryParamRegex 匹配 URL query 中可能泄露凭证的参数:
+// key / api_key / api-key / access_token / token / authorization / x-api-key。
+// 大小写不敏感,匹配 `?name=value` 或 `&name=value` 形式(value 截到 & 或字符串末尾)。
+var monitorSensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|api[_-]?key|access[_-]?token|token|authorization|x-api-key)=)[^&\s"']+`)
+
+// monitorAPIKeyPatterns 匹配常见 provider 的 API key 字面量。
+// 顺序敏感:sk-ant- 必须放在 sk- 之前,否则会被通用 sk- 模式先消费。
+var monitorAPIKeyPatterns = []struct {
+ pattern *regexp.Regexp
+ replace string
+}{
+ // Anthropic(带前缀,必须先匹配):sk-ant-xxxxxxx
+ {regexp.MustCompile(`sk-ant-[A-Za-z0-9_-]{20,}`), "sk-ant-***REDACTED***"},
+ // OpenAI / Anthropic 通用 sk-: sk-xxxxxxx
+ {regexp.MustCompile(`sk-[A-Za-z0-9-]{20,}`), "sk-***REDACTED***"},
+ // Gemini / Google API Key:固定前缀 + 35 位
+ {regexp.MustCompile(`AIza[A-Za-z0-9_-]{35}`), "AIza***REDACTED***"},
+ // JWT 三段式(Bearer 后常出现):eyJxxx.eyJxxx.signature
+ {regexp.MustCompile(`eyJ[A-Za-z0-9_-]{8,}\.eyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{8,}`), "eyJ***REDACTED.JWT***"},
+}
+
+// sanitizeErrorMessage 擦除错误/响应文本中可能泄露的 API key。
+// 处理两类来源:
+// 1. URL query 中的 ?key= / ?api_key= 等(Go *url.Error 会回填完整 URL)
+// 2. 上游 HTTP body 文本里直接出现的 sk-* / AIza* / JWT 等密钥碎片
+//
+// 注意:与 gemini_messages_compat_service.go 的 sanitizeUpstreamErrorMessage 关注点类似但参数集更广,
+// 监控模块独立维护,避免互相耦合。
+func sanitizeErrorMessage(msg string) string {
+ if msg == "" {
+ return msg
+ }
+ msg = monitorSensitiveQueryParamRegex.ReplaceAllString(msg, `${1}REDACTED`)
+ for _, p := range monitorAPIKeyPatterns {
+ msg = p.pattern.ReplaceAllString(msg, p.replace)
+ }
+ return msg
+}
+
+// truncateMessage 把消息按 monitorMessageMaxBytes 截断,避免 DB 列溢出与日志过长。
+func truncateMessage(msg string) string {
+ if len(msg) <= monitorMessageMaxBytes {
+ return msg
+ }
+ const ellipsis = "...(truncated)"
+ cutoff := monitorMessageMaxBytes - len(ellipsis)
+ if cutoff < 0 {
+ cutoff = 0
+ }
+ return msg[:cutoff] + ellipsis
+}
+
+// truncateForErrorBody 把上游错误响应 body 压到 monitorErrorBodySnippetMaxBytes 以内,
+// 并顺手把连续空白折成一个空格:上游 HTML 错误页常含大量缩进/换行,保留会浪费预算。
+// 被 truncateMessage 做最终总截断兜底,所以这里只负责 body 自身的精简。
+func truncateForErrorBody(body string) string {
+ body = strings.Join(strings.Fields(body), " ")
+ if len(body) <= monitorErrorBodySnippetMaxBytes {
+ return body
+ }
+ const ellipsis = "...(body truncated)"
+ cutoff := monitorErrorBodySnippetMaxBytes - len(ellipsis)
+ if cutoff < 0 {
+ cutoff = 0
+ }
+ return body[:cutoff] + ellipsis
+}
diff --git a/backend/internal/service/channel_monitor_checker_body_test.go b/backend/internal/service/channel_monitor_checker_body_test.go
new file mode 100644
index 00000000..323cf8b7
--- /dev/null
+++ b/backend/internal/service/channel_monitor_checker_body_test.go
@@ -0,0 +1,173 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+)
+
+// swapMonitorHTTPClient 临时替换 monitorHTTPClient 为不带 SSRF 校验的普通 client,
+// 让 httptest (127.0.0.1) 能连通。测试结束后恢复。
+func swapMonitorHTTPClient(t *testing.T) {
+ t.Helper()
+ orig := monitorHTTPClient
+ monitorHTTPClient = &http.Client{Timeout: 5 * time.Second}
+ t.Cleanup(func() { monitorHTTPClient = orig })
+}
+
+// captureHandler 把每次收到的请求 body 和 headers 存起来,测试断言用。
+type captureHandler struct {
+ lastBody map[string]any
+ lastHeaders http.Header
+ respondText string // 写到 Anthropic content[0].text 里(校验用)
+ status int
+}
+
+func (h *captureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ h.lastHeaders = r.Header.Clone()
+ defer func() { _ = r.Body.Close() }()
+ var parsed map[string]any
+ _ = json.NewDecoder(r.Body).Decode(&parsed)
+ h.lastBody = parsed
+
+ if h.status == 0 {
+ h.status = 200
+ }
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(h.status)
+ // 构造 Anthropic 格式的响应:content[0].text = h.respondText
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "content": []map[string]any{
+ {"type": "text", "text": h.respondText},
+ },
+ })
+}
+
+func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
+ t.Helper()
+ swapMonitorHTTPClient(t)
+ srv := httptest.NewServer(handler)
+ t.Cleanup(srv.Close)
+ return srv.URL
+}
+
+func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
+ h := &captureHandler{respondText: "the answer is 42"}
+ endpoint := setupFakeAnthropic(t, h)
+
+ // 跑一次 off 模式(opts=nil),确认默认 body 行为未变
+ _ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", nil)
+
+ if h.lastBody["model"] != "claude-x" {
+ t.Errorf("default body should contain model=claude-x, got %v", h.lastBody["model"])
+ }
+ if _, ok := h.lastBody["messages"]; !ok {
+ t.Error("default body should contain messages")
+ }
+ if h.lastHeaders.Get("x-api-key") != "sk-fake" {
+ t.Errorf("expected adapter's x-api-key header, got %q", h.lastHeaders.Get("x-api-key"))
+ }
+}
+
+func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
+ h := &captureHandler{respondText: "the answer is 42"}
+ endpoint := setupFakeAnthropic(t, h)
+
+ opts := &CheckOptions{
+ BodyOverrideMode: MonitorBodyOverrideModeMerge,
+ BodyOverride: map[string]any{
+ "system": "You are Claude Code...",
+ "max_tokens": float64(999), // 应该覆盖默认 50
+ "model": "hacked-model", // 应该被黑名单挡住,保留原 model
+ "messages": []any{}, // 同上,被挡
+ },
+ ExtraHeaders: map[string]string{
+ "User-Agent": "claude-cli/1.0",
+ "Content-Length": "999", // 黑名单
+ "x-custom": "ok",
+ },
+ }
+ _ = runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
+
+ if h.lastBody["system"] != "You are Claude Code..." {
+ t.Errorf("merge mode should inject system, got %v", h.lastBody["system"])
+ }
+ // max_tokens 覆盖生效
+ if mt, ok := h.lastBody["max_tokens"].(float64); !ok || mt != 999 {
+ t.Errorf("merge mode should override max_tokens to 999, got %v", h.lastBody["max_tokens"])
+ }
+ // model 在黑名单 — 应该保留默认值
+ if h.lastBody["model"] != "claude-x" {
+ t.Errorf("model should be protected by deny list, got %v", h.lastBody["model"])
+ }
+ // messages 在黑名单 — 应该保留默认值(非空)
+ msgs, _ := h.lastBody["messages"].([]any)
+ if len(msgs) == 0 {
+ t.Error("messages should be protected by deny list (kept default, non-empty)")
+ }
+ // header 合并
+ if h.lastHeaders.Get("User-Agent") != "claude-cli/1.0" {
+ t.Errorf("extra User-Agent should override, got %q", h.lastHeaders.Get("User-Agent"))
+ }
+ if h.lastHeaders.Get("x-custom") != "ok" {
+ t.Errorf("extra custom header should be present, got %q", h.lastHeaders.Get("x-custom"))
+ }
+ // Content-Length 黑名单:会被 net/http 自动重算,但不应由用户的 "999" 决定。
+ // 我们无法直接断言丢弃(http.Client 总会填上),只断言请求成功即可。
+}
+
+func TestRunCheckForModel_ReplaceMode_FullBodyUsedAndChallengeSkipped(t *testing.T) {
+ // replace 模式下我们的 body 完全自定义,challenge 数学题不会出现在请求里,
+ // 上游也不会回正确答案 — 但只要 2xx + 响应文本非空,就算 operational
+ h := &captureHandler{respondText: "any non-empty text"}
+ endpoint := setupFakeAnthropic(t, h)
+
+ userBody := map[string]any{
+ "model": "user-forced-model",
+ "messages": []any{map[string]any{"role": "user", "content": "hi"}},
+ "max_tokens": float64(10),
+ "system": "You are someone else",
+ }
+ opts := &CheckOptions{
+ BodyOverrideMode: MonitorBodyOverrideModeReplace,
+ BodyOverride: userBody,
+ }
+ res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
+
+ // 请求 body = 用户提供的原样
+ if h.lastBody["model"] != "user-forced-model" {
+ t.Errorf("replace mode should use user's model, got %v", h.lastBody["model"])
+ }
+ if h.lastBody["system"] != "You are someone else" {
+ t.Errorf("replace mode should use user's system, got %v", h.lastBody["system"])
+ }
+ // challenge 虽然没命中,但由于 replace 模式跳过 challenge 校验 + 响应非空 → operational
+ if res.Status != MonitorStatusOperational {
+ t.Errorf("replace mode with 2xx + non-empty text should be operational, got status=%s message=%q",
+ res.Status, res.Message)
+ }
+}
+
+func TestRunCheckForModel_ReplaceMode_EmptyResponseIsFailed(t *testing.T) {
+ h := &captureHandler{respondText: ""} // 上游 200 但 content[0].text 为空
+ endpoint := setupFakeAnthropic(t, h)
+
+ opts := &CheckOptions{
+ BodyOverrideMode: MonitorBodyOverrideModeReplace,
+ BodyOverride: map[string]any{"model": "x", "messages": []any{}},
+ }
+ res := runCheckForModel(context.Background(), MonitorProviderAnthropic, endpoint, "sk-fake", "claude-x", opts)
+
+ if res.Status != MonitorStatusFailed {
+ t.Errorf("replace mode with empty text should be failed, got status=%s", res.Status)
+ }
+ if !strings.Contains(res.Message, "replace-mode") {
+ t.Errorf("failure message should hint replace-mode, got %q", res.Message)
+ }
+}
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
new file mode 100644
index 00000000..2e1614f7
--- /dev/null
+++ b/backend/internal/service/channel_monitor_const.go
@@ -0,0 +1,142 @@
+package service
+
+import (
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+// ChannelMonitor 全局常量。
+// 这些是 MVP 阶段的硬编码值,按需可以提到 config 中。
+const (
+ // monitorRequestTimeout 单次模型请求总超时(含 Body 读取)。
+ monitorRequestTimeout = 45 * time.Second
+ // monitorPingTimeout HEAD 请求 endpoint origin 的超时。
+ monitorPingTimeout = 8 * time.Second
+ // monitorDegradedThreshold 主请求成功但耗时超过该阈值视为 degraded。
+ monitorDegradedThreshold = 6 * time.Second
+ // monitorHistoryRetentionDays 明细历史保留天数。
+ // 60s 默认间隔 * 30 天 ≈ 43200 行/monitor/model,一般部署总量 <= 2M 行,
+ // PG 无压力;所以直接保留完整明细一个月,可用率查询可以全走原始行不依赖聚合。
+ // 聚合表 channel_monitor_daily_rollups 仍然保留,作为长期历史回填/降级查询的兜底。
+ monitorHistoryRetentionDays = 30
+ // monitorRollupRetentionDays 日聚合保留天数。
+ // 日聚合行由 RunDailyMaintenance 在超过该窗口后软删。
+ monitorRollupRetentionDays = 30
+ // monitorMaintenanceMaxDaysPerRun 单次维护任务最多聚合的天数。
+ // 用于限制首次上线回填(30 天)+ 少量余量,避免长事务。
+ monitorMaintenanceMaxDaysPerRun = 35
+ // monitorWorkerConcurrency 调度器并发执行的监控数(pond 池容量)。
+ monitorWorkerConcurrency = 5
+ // monitorStartupLoadTimeout Start 时一次性加载所有 enabled monitor 的总超时。
+ monitorStartupLoadTimeout = 10 * time.Second
+ // monitorMinIntervalSeconds / monitorMaxIntervalSeconds 用户配置的检测间隔上下限。
+ monitorMinIntervalSeconds = 15
+ monitorMaxIntervalSeconds = 3600
+ // monitorMessageMaxBytes message 字段最大字节数(与 schema/migration 一致)。
+ monitorMessageMaxBytes = 500
+ // monitorResponseMaxBytes 单次模型响应最大读取字节,防止 OOM。
+ monitorResponseMaxBytes = 64 * 1024
+ // monitorErrorBodySnippetMaxBytes 非 2xx 响应时保留上游 body 片段的最大字节数。
+ // 留 300 字节足够覆盖典型结构化错误(如 `{"error":{"message":"..."}}`),
+ // 又给 "upstream HTTP : " 前缀留出余量,避免最终被 monitorMessageMaxBytes (500) 截得太狠。
+ monitorErrorBodySnippetMaxBytes = 300
+ // monitorChallengeMin / monitorChallengeMax challenge 操作数范围。
+ monitorChallengeMin = 1
+ monitorChallengeMax = 50
+
+ // providerOpenAIPath OpenAI Chat Completions 路径。
+ providerOpenAIPath = "/v1/chat/completions"
+ // providerAnthropicPath Anthropic Messages 路径。
+ providerAnthropicPath = "/v1/messages"
+ // providerGeminiPathTemplate Gemini generateContent 路径模板(含 model 占位)。
+ providerGeminiPathTemplate = "/v1beta/models/%s:generateContent"
+
+ // MonitorProviderOpenAI / Anthropic / Gemini provider 字符串常量(也是 ent enum 的实际值)。
+ MonitorProviderOpenAI = "openai"
+ MonitorProviderAnthropic = "anthropic"
+ MonitorProviderGemini = "gemini"
+
+ // MonitorStatusOperational 等监控状态字符串常量(与 ent enum 一致)。
+ MonitorStatusOperational = "operational"
+ MonitorStatusDegraded = "degraded"
+ MonitorStatusFailed = "failed"
+ MonitorStatusError = "error"
+
+ // monitorAvailability7Days / 15 / 30 用于聚合查询窗口。
+ monitorAvailability7Days = 7
+ monitorAvailability15Days = 15
+ monitorAvailability30Days = 30
+
+ // MonitorHistoryDefaultLimit 历史查询默认返回条数(handler 层共享)。
+ MonitorHistoryDefaultLimit = 100
+ // MonitorHistoryMaxLimit 历史查询最大返回条数(handler 层共享)。
+ MonitorHistoryMaxLimit = 1000
+
+ // monitorTimelineMaxPoints 用户视图 timeline 每个监控最多返回的历史点数。
+ monitorTimelineMaxPoints = 60
+
+ // monitorEndpointResolveTimeout validateEndpoint 解析 hostname 的最长耗时。
+ monitorEndpointResolveTimeout = 5 * time.Second
+
+ // ---- checker / runner 行为参数(消除 magic 值)----
+
+ // monitorAnthropicAPIVersion Anthropic Messages API 版本头。
+ monitorAnthropicAPIVersion = "2023-06-01"
+ // monitorChallengeMaxTokens 单次 challenge 请求的 max_tokens(足够回答个位数算术)。
+ monitorChallengeMaxTokens = 50
+
+ // monitorRunOneBuffer runOne 的总超时缓冲(除请求超时与 ping 超时外的额外裕量)。
+ monitorRunOneBuffer = 10 * time.Second
+
+ // monitorIdleConnTimeout HTTP transport 空闲连接关闭超时。
+ monitorIdleConnTimeout = 30 * time.Second
+ // monitorTLSHandshakeTimeout HTTP transport TLS 握手超时。
+ monitorTLSHandshakeTimeout = 10 * time.Second
+ // monitorResponseHeaderTimeout HTTP transport 等待响应头超时。
+ monitorResponseHeaderTimeout = 30 * time.Second
+ // monitorPingDiscardMaxBytes ping 时丢弃响应体的最大字节数。
+ monitorPingDiscardMaxBytes = 1024
+
+ // monitorDialTimeout 自定义 dialer 单次连接超时。
+ monitorDialTimeout = 10 * time.Second
+ // monitorDialKeepAlive 自定义 dialer keep-alive 间隔。
+ monitorDialKeepAlive = 30 * time.Second
+)
+
+// 业务错误(统一在此声明,避免散落)。
+var (
+ ErrChannelMonitorNotFound = infraerrors.NotFound(
+ "CHANNEL_MONITOR_NOT_FOUND", "channel monitor not found",
+ )
+ ErrChannelMonitorInvalidProvider = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_INVALID_PROVIDER", "provider must be one of openai/anthropic/gemini",
+ )
+ ErrChannelMonitorInvalidInterval = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_INVALID_INTERVAL", "interval_seconds must be in [15, 3600]",
+ )
+ ErrChannelMonitorInvalidEndpoint = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_INVALID_ENDPOINT", "endpoint must be a valid https URL",
+ )
+ ErrChannelMonitorEndpointScheme = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_ENDPOINT_SCHEME", "endpoint must use https scheme",
+ )
+ ErrChannelMonitorEndpointPath = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_ENDPOINT_PATH", "endpoint must be base origin only (no path/query/fragment)",
+ )
+ ErrChannelMonitorEndpointPrivate = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_ENDPOINT_PRIVATE", "endpoint must be a public host",
+ )
+ ErrChannelMonitorEndpointUnreachable = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_ENDPOINT_UNREACHABLE", "endpoint hostname could not be resolved",
+ )
+ ErrChannelMonitorMissingAPIKey = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_MISSING_API_KEY", "api_key is required when creating a monitor",
+ )
+ ErrChannelMonitorMissingPrimaryModel = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_MISSING_PRIMARY_MODEL", "primary_model is required",
+ )
+ ErrChannelMonitorAPIKeyDecryptFailed = infraerrors.InternalServer(
+ "CHANNEL_MONITOR_KEY_DECRYPT_FAILED", "api key decryption failed; please re-edit the monitor with a fresh key",
+ )
+)
diff --git a/backend/internal/service/channel_monitor_runner.go b/backend/internal/service/channel_monitor_runner.go
new file mode 100644
index 00000000..08178bc6
--- /dev/null
+++ b/backend/internal/service/channel_monitor_runner.go
@@ -0,0 +1,291 @@
+package service
+
+import (
+ "context"
+ "log/slog"
+ "sync"
+ "time"
+
+ "github.com/alitto/pond/v2"
+)
+
+// MonitorScheduler 调度器接口,供 ChannelMonitorService 在 CRUD 时回调,
+// 用 setter 注入避免 service ↔ runner 的 wire 依赖环。
+type MonitorScheduler interface {
+ // Schedule 为指定监控创建(或重置)独立定时任务。
+ // 当 m.Enabled=false 时等同于 Unschedule(m.ID)。
+ Schedule(m *ChannelMonitor)
+ // Unschedule 取消指定监控的定时任务(若存在)。
+ Unschedule(id int64)
+}
+
+// monitorRunnerSvc 抽出 runner 实际依赖的两个 service 方法:
+// - 启动时加载 enabled monitor
+// - 每次 ticker 触发执行检测
+//
+// 用接口而非 *ChannelMonitorService 是为了让 runner 单元测试可注入轻量 stub,
+// 避免依赖完整的 repo + encryptor 链路。生产实现 *ChannelMonitorService 自然满足。
+type monitorRunnerSvc interface {
+ ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error)
+ RunCheck(ctx context.Context, id int64) ([]*CheckResult, error)
+}
+
+// ChannelMonitorRunner 渠道监控调度器。
+//
+// 设计:
+// - 每个 enabled monitor 对应一个独立 goroutine + ticker(按各自 IntervalSeconds)
+// - Start 时一次性加载所有 enabled monitor 并为每个建立任务
+// - Service 在 Create/Update/Delete 后通过 MonitorScheduler 接口回调,
+// 即时重建/取消对应任务(无需轮询 DB)
+// - 实际 HTTP 检测交给 pond 池(容量 monitorWorkerConcurrency),
+// 防止突发并发拖垮上游
+//
+// 历史清理与日聚合维护由 OpsCleanupService 的 cron 触发
+// ChannelMonitorService.RunDailyMaintenance(复用 leader lock + heartbeat),
+// 不在 runner 职责内。
+type ChannelMonitorRunner struct {
+ svc monitorRunnerSvc
+ settingService *SettingService
+
+ pool pond.Pool
+ parentCtx context.Context
+ parentCancel context.CancelFunc
+
+ mu sync.Mutex
+ tasks map[int64]*scheduledMonitor
+ wg sync.WaitGroup
+ started bool
+ stopped bool
+
+ // inFlight 跟踪正在执行的 monitor.ID。fire 调度前会检查避免重复提交,
+ // 防止单次检测耗时 > interval 时同一 monitor 被并发执行。
+ inFlight map[int64]struct{}
+ inFlightMu sync.Mutex
+}
+
+// scheduledMonitor 单个监控的运行时上下文。
+type scheduledMonitor struct {
+ id int64
+ name string
+ interval time.Duration
+ cancel context.CancelFunc
+}
+
+// NewChannelMonitorRunner 构造调度器。Start 在 wire 中调用一次。
+// settingService 用于在每次 fire 前读取功能开关;传 nil 时视为总是启用(兼容测试)。
+//
+// pool 在构造时即建好:避免 Start 在 mu 内赋值、fire/Stop 在 mu 外读取的竞态隐患,
+// 且 pond.NewPool 创建本身近似零开销,提前建池不会浪费资源。
+func NewChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
+ return newChannelMonitorRunner(svc, settingService)
+}
+
+// newChannelMonitorRunner 内部构造,接受最小化接口,便于单元测试注入 stub。
+func newChannelMonitorRunner(svc monitorRunnerSvc, settingService *SettingService) *ChannelMonitorRunner {
+ ctx, cancel := context.WithCancel(context.Background())
+ return &ChannelMonitorRunner{
+ svc: svc,
+ settingService: settingService,
+ pool: pond.NewPool(monitorWorkerConcurrency),
+ parentCtx: ctx,
+ parentCancel: cancel,
+ tasks: make(map[int64]*scheduledMonitor),
+ inFlight: make(map[int64]struct{}),
+ }
+}
+
+// Start 加载所有 enabled monitor 并为每个建立独立定时任务。
+// 调用方需保证只调一次(wire ProvideChannelMonitorRunner 内只调一次)。
+func (r *ChannelMonitorRunner) Start() {
+ if r == nil || r.svc == nil {
+ return
+ }
+ r.mu.Lock()
+ if r.started || r.stopped {
+ r.mu.Unlock()
+ return
+ }
+ r.started = true
+ r.mu.Unlock()
+
+ ctx, cancel := context.WithTimeout(context.Background(), monitorStartupLoadTimeout)
+ defer cancel()
+ enabled, err := r.svc.ListEnabledMonitors(ctx)
+ if err != nil {
+ slog.Error("channel_monitor: load enabled monitors failed at startup", "error", err)
+ return
+ }
+ for _, m := range enabled {
+ r.Schedule(m)
+ }
+ slog.Info("channel_monitor: runner started", "scheduled_tasks", len(enabled))
+}
+
+// Schedule 为指定监控创建(或重置)独立定时任务。
+// - m.Enabled=false → 等同于 Unschedule(m.ID)
+// - 已存在的任务会先被取消再重建(适用于 IntervalSeconds 变更场景)
+// - 新任务立即触发首次检测,之后按 IntervalSeconds 周期触发
+func (r *ChannelMonitorRunner) Schedule(m *ChannelMonitor) {
+ if r == nil || m == nil {
+ return
+ }
+ if !m.Enabled {
+ r.Unschedule(m.ID)
+ return
+ }
+ interval := time.Duration(m.IntervalSeconds) * time.Second
+ if interval <= 0 {
+ // Create/Update 已通过 validateInterval 校验区间,正常路径不可能到这里。
+ // 真触发说明数据库中存在违反约束的数据或校验链路有 bug,记 Error 暴露问题。
+ slog.Error("channel_monitor: skip schedule for invalid interval",
+ "monitor_id", m.ID, "interval_seconds", m.IntervalSeconds)
+ return
+ }
+
+ r.mu.Lock()
+ if r.stopped {
+ r.mu.Unlock()
+ return
+ }
+ if !r.started {
+ // Start 之前调用 Schedule 通常意味着 wire 顺序错乱:
+ // 当前 wire 顺序是 SetScheduler → Start,CRUD 钩子最早也只能在请求到达时触发,
+ // 此时 Start 早已完成。出现此分支时把 monitor 信息打出来便于排查,
+ // 不入队、不缓存——交给运维通过重启或修复 wire 解决。
+ r.mu.Unlock()
+ slog.Warn("channel_monitor: schedule before runner started, skip",
+ "monitor_id", m.ID, "name", m.Name)
+ return
+ }
+ if existing, ok := r.tasks[m.ID]; ok {
+ existing.cancel()
+ }
+ ctx, cancel := context.WithCancel(r.parentCtx)
+ task := &scheduledMonitor{
+ id: m.ID,
+ name: m.Name,
+ interval: interval,
+ cancel: cancel,
+ }
+ r.tasks[m.ID] = task
+ r.wg.Add(1)
+ r.mu.Unlock()
+
+ go r.runScheduled(ctx, task)
+}
+
+// Unschedule 取消指定监控的定时任务(若存在)。
+// 已经在执行中的检测会通过 ctx 取消信号传递。
+func (r *ChannelMonitorRunner) Unschedule(id int64) {
+ if r == nil {
+ return
+ }
+ r.mu.Lock()
+ task, ok := r.tasks[id]
+ if ok {
+ delete(r.tasks, id)
+ }
+ r.mu.Unlock()
+ if ok {
+ task.cancel()
+ }
+}
+
+// Stop 优雅停止:取消所有任务、关闭池。
+func (r *ChannelMonitorRunner) Stop() {
+ if r == nil {
+ return
+ }
+ r.mu.Lock()
+ if r.stopped {
+ r.mu.Unlock()
+ return
+ }
+ r.stopped = true
+ r.parentCancel()
+ r.tasks = nil
+ r.mu.Unlock()
+
+ r.wg.Wait()
+ r.pool.StopAndWait()
+}
+
+// runScheduled 单个监控的循环:立即触发首次(满足"新建/启用即跑"),
+// 之后按 interval 周期触发;ctx 取消即退出。
+func (r *ChannelMonitorRunner) runScheduled(ctx context.Context, task *scheduledMonitor) {
+ defer r.wg.Done()
+
+ r.fire(ctx, task)
+
+ ticker := time.NewTicker(task.interval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ r.fire(ctx, task)
+ }
+ }
+}
+
+// fire 提交一次检测到 worker 池。功能开关关闭时跳过本次(不取消任务,
+// 重新启用时立即恢复);池满或重复在飞时也跳过。
+func (r *ChannelMonitorRunner) fire(ctx context.Context, task *scheduledMonitor) {
+ if r.settingService != nil && !r.settingService.GetChannelMonitorRuntime(ctx).Enabled {
+ return
+ }
+ if !r.tryAcquireInFlight(task.id) {
+ slog.Debug("channel_monitor: skip already in-flight",
+ "monitor_id", task.id, "name", task.name)
+ return
+ }
+ if _, ok := r.pool.TrySubmit(func() {
+ r.runOne(task.id, task.name)
+ }); !ok {
+ // 池满:丢弃本次检测,但必须释放已占用的 inFlight 槽,否则该 monitor 会被永久卡住。
+ r.releaseInFlight(task.id)
+ slog.Warn("channel_monitor: worker pool full, skip submission",
+ "monitor_id", task.id, "name", task.name)
+ }
+}
+
+// tryAcquireInFlight 原子地占用 monitor 的 in-flight 槽。
+// 已被占用返回 false(调用方应跳过本次提交)。
+func (r *ChannelMonitorRunner) tryAcquireInFlight(id int64) bool {
+ r.inFlightMu.Lock()
+ defer r.inFlightMu.Unlock()
+ if _, exists := r.inFlight[id]; exists {
+ return false
+ }
+ r.inFlight[id] = struct{}{}
+ return true
+}
+
+// releaseInFlight 释放 in-flight 槽。runOne 完成(含 panic recover)后必须调用。
+func (r *ChannelMonitorRunner) releaseInFlight(id int64) {
+ r.inFlightMu.Lock()
+ delete(r.inFlight, id)
+ r.inFlightMu.Unlock()
+}
+
+// runOne 执行单个监控的检测。所有错误只记日志,不熔断。
+// 任务结束时(含 panic recover)必须释放 in-flight 槽。
+func (r *ChannelMonitorRunner) runOne(id int64, name string) {
+ ctx, cancel := context.WithTimeout(context.Background(), monitorRequestTimeout+monitorPingTimeout+monitorRunOneBuffer)
+ defer cancel()
+
+ defer r.releaseInFlight(id)
+
+ defer func() {
+ if rec := recover(); rec != nil {
+ slog.Error("channel_monitor: runner panic",
+ "monitor_id", id, "name", name, "panic", rec)
+ }
+ }()
+
+ if _, err := r.svc.RunCheck(ctx, id); err != nil {
+ slog.Warn("channel_monitor: run check failed",
+ "monitor_id", id, "name", name, "error", err)
+ }
+}
diff --git a/backend/internal/service/channel_monitor_runner_test.go b/backend/internal/service/channel_monitor_runner_test.go
new file mode 100644
index 00000000..5eed3c20
--- /dev/null
+++ b/backend/internal/service/channel_monitor_runner_test.go
@@ -0,0 +1,277 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// stubMonitorSvc 实现 monitorRunnerSvc,用于隔离 runner 与真实 service/repo。
+type stubMonitorSvc struct {
+ enabled []*ChannelMonitor
+ runCount atomic.Int64
+ runCalled chan int64 // 每次 RunCheck 触发时 push 一次(缓冲足够大避免阻塞)
+ runErr error
+ listErr error
+ runHoldFor time.Duration // RunCheck 内额外阻塞的时长,用来测试 Stop 等待行为
+}
+
+func (s *stubMonitorSvc) ListEnabledMonitors(_ context.Context) ([]*ChannelMonitor, error) {
+ if s.listErr != nil {
+ return nil, s.listErr
+ }
+ return s.enabled, nil
+}
+
+func (s *stubMonitorSvc) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
+ s.runCount.Add(1)
+ if s.runCalled != nil {
+ select {
+ case s.runCalled <- id:
+ default:
+ }
+ }
+ if s.runHoldFor > 0 {
+ select {
+ case <-time.After(s.runHoldFor):
+ case <-ctx.Done():
+ }
+ }
+ return nil, s.runErr
+}
+
+func newRunnerForTest(svc monitorRunnerSvc) *ChannelMonitorRunner {
+ return newChannelMonitorRunner(svc, nil)
+}
+
+// 等待 condition 在 timeout 内变 true,否则 t.Fatalf。轮询 5ms 一次。
+func waitFor(t *testing.T, timeout time.Duration, msg string, cond func() bool) {
+ t.Helper()
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ if cond() {
+ return
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+ if !cond() {
+ t.Fatalf("waitFor timed out: %s", msg)
+ }
+}
+
+func runnerTaskCount(r *ChannelMonitorRunner) int {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return len(r.tasks)
+}
+
+func runnerTaskPtr(r *ChannelMonitorRunner, id int64) *scheduledMonitor {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.tasks[id]
+}
+
+// TestSchedule_AddsTaskAndFiresOnce 验证 Schedule 后立即触发一次首检测,并把任务记入 tasks 表。
+func TestSchedule_AddsTaskAndFiresOnce(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
+ r := newRunnerForTest(svc)
+ r.Start() // svc.enabled 为空,Start 立即完成
+
+ r.Schedule(&ChannelMonitor{ID: 1, Name: "m1", Enabled: true, IntervalSeconds: 60})
+
+ if got := runnerTaskCount(r); got != 1 {
+ t.Fatalf("expected 1 scheduled task, got %d", got)
+ }
+
+ select {
+ case id := <-svc.runCalled:
+ if id != 1 {
+ t.Fatalf("expected first fire for id=1, got %d", id)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected immediate first fire within 2s")
+ }
+
+ r.Stop()
+}
+
+// TestSchedule_ReplaceCancelsOldTask 验证对同一 id 二次 Schedule 会替换旧 task 实例。
+// (旧 goroutine 通过 ctx 取消退出;这里以 task 指针不同 + Stop 不超时作为证据。)
+func TestSchedule_ReplaceCancelsOldTask(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 8)}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ m := &ChannelMonitor{ID: 7, Name: "m7", Enabled: true, IntervalSeconds: 60}
+ r.Schedule(m)
+ first := runnerTaskPtr(r, 7)
+ if first == nil {
+ t.Fatal("first schedule did not register task")
+ }
+
+ r.Schedule(m)
+ second := runnerTaskPtr(r, 7)
+ if second == nil {
+ t.Fatal("second schedule did not register task")
+ }
+ if first == second {
+ t.Fatal("re-Schedule should create a new scheduledMonitor instance")
+ }
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestUnschedule_RemovesTask 验证 Unschedule 删除 task 并使对应 goroutine 退出。
+func TestUnschedule_RemovesTask(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ r.Schedule(&ChannelMonitor{ID: 3, Enabled: true, IntervalSeconds: 60})
+ waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(r) == 1 })
+
+ r.Unschedule(3)
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected tasks empty after Unschedule, got %d", got)
+ }
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestSchedule_DisabledRedirectsToUnschedule 验证 Enabled=false 等同于 Unschedule。
+func TestSchedule_DisabledRedirectsToUnschedule(t *testing.T) {
+ svc := &stubMonitorSvc{runCalled: make(chan int64, 4)}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ r.Schedule(&ChannelMonitor{ID: 9, Enabled: true, IntervalSeconds: 60})
+ waitFor(t, time.Second, "task registered", func() bool { return runnerTaskCount(r) == 1 })
+
+ r.Schedule(&ChannelMonitor{ID: 9, Enabled: false, IntervalSeconds: 60})
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected tasks empty after disabled re-Schedule, got %d", got)
+ }
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestSchedule_InvalidIntervalSkipped 验证 IntervalSeconds<=0 不会注册任务(防御性检查)。
+func TestSchedule_InvalidIntervalSkipped(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 0})
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected no task for invalid interval, got %d", got)
+ }
+ r.Stop()
+}
+
+// TestSchedule_BeforeStartIsNoOp 验证 Start 之前调用 Schedule 不会注册任务。
+func TestSchedule_BeforeStartIsNoOp(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+ // 故意不调用 Start
+
+ r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
+ if got := runnerTaskCount(r); got != 0 {
+ t.Fatalf("expected no task before Start, got %d", got)
+ }
+ r.Stop()
+}
+
+// TestStart_LoadsAllEnabledMonitors 验证 Start 会为 ListEnabledMonitors 返回的每条记录建立任务。
+func TestStart_LoadsAllEnabledMonitors(t *testing.T) {
+ svc := &stubMonitorSvc{
+ enabled: []*ChannelMonitor{
+ {ID: 1, Enabled: true, IntervalSeconds: 60},
+ {ID: 2, Enabled: true, IntervalSeconds: 60},
+ {ID: 3, Enabled: true, IntervalSeconds: 60},
+ },
+ }
+ r := newRunnerForTest(svc)
+ r.Start()
+ waitFor(t, 2*time.Second, "all 3 tasks scheduled", func() bool { return runnerTaskCount(r) == 3 })
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestStop_DrainsAllGoroutines 验证 Stop 会等待所有调度 goroutine 退出(无游离)。
+func TestStop_DrainsAllGoroutines(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+ r.Start()
+
+ for id := int64(1); id <= 5; id++ {
+ r.Schedule(&ChannelMonitor{ID: id, Enabled: true, IntervalSeconds: 60})
+ }
+ waitFor(t, 2*time.Second, "5 tasks scheduled", func() bool { return runnerTaskCount(r) == 5 })
+
+ stoppedWithin(t, r, 3*time.Second)
+}
+
+// TestStop_WaitsForInFlightCheck 验证 Stop 会等待正在执行的 RunCheck 退出(pool.StopAndWait)。
+func TestStop_WaitsForInFlightCheck(t *testing.T) {
+ svc := &stubMonitorSvc{
+ runCalled: make(chan int64, 1),
+ runHoldFor: 200 * time.Millisecond,
+ }
+ r := newRunnerForTest(svc)
+ r.Start()
+ r.Schedule(&ChannelMonitor{ID: 1, Enabled: true, IntervalSeconds: 60})
+
+ select {
+ case <-svc.runCalled:
+ case <-time.After(2 * time.Second):
+ t.Fatal("first fire never happened")
+ }
+
+ start := time.Now()
+ stoppedWithin(t, r, 3*time.Second)
+ elapsed := time.Since(start)
+ // Stop 必须等待 in-flight check 跑完(runHoldFor=200ms),耗时下界约 100ms。
+ if elapsed < 100*time.Millisecond {
+ t.Fatalf("Stop returned too fast (%v); did not wait for in-flight check", elapsed)
+ }
+}
+
+// TestInFlight_PoolFullReleasesSlot 直接驱动 fire 路径,模拟 pool.TrySubmit 失败时 inFlight 必须释放。
+// 用一个小型 stub pool 替换 r.pool 不便(pond.Pool 是接口但 mock 麻烦),
+// 改为:占满 inFlight 后直接 fire,验证不会在 inFlight 空槽时永久卡住。
+func TestInFlight_AcquireReleaseSymmetric(t *testing.T) {
+ svc := &stubMonitorSvc{}
+ r := newRunnerForTest(svc)
+
+ if !r.tryAcquireInFlight(42) {
+ t.Fatal("first acquire should succeed")
+ }
+ if r.tryAcquireInFlight(42) {
+ t.Fatal("second acquire (no release) must fail")
+ }
+ r.releaseInFlight(42)
+ if !r.tryAcquireInFlight(42) {
+ t.Fatal("acquire after release should succeed")
+ }
+ r.releaseInFlight(42)
+}
+
+// stoppedWithin 在 timeout 内并行调用 Stop,超时则 Fatal。验证 Stop 不会阻塞。
+func stoppedWithin(t *testing.T, r *ChannelMonitorRunner, timeout time.Duration) {
+ t.Helper()
+ done := make(chan struct{})
+ var once sync.Once
+ go func() {
+ r.Stop()
+ once.Do(func() { close(done) })
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Fatalf("Stop did not return within %s — leaked goroutine?", timeout)
+ }
+}
diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go
new file mode 100644
index 00000000..7050e141
--- /dev/null
+++ b/backend/internal/service/channel_monitor_service.go
@@ -0,0 +1,539 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/sync/errgroup"
+)
+
+// ChannelMonitorRepository 渠道监控数据访问接口。
+// 入参/返回的指针类型均使用 service 包的 ChannelMonitor 模型,
+// repository 实现负责与 ent 模型互转,并保持 api_key_encrypted 字段为密文。
+type ChannelMonitorRepository interface {
+ // CRUD
+ Create(ctx context.Context, m *ChannelMonitor) error
+ GetByID(ctx context.Context, id int64) (*ChannelMonitor, error)
+ Update(ctx context.Context, m *ChannelMonitor) error
+ Delete(ctx context.Context, id int64) error
+ List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error)
+
+ // 调度器辅助
+ ListEnabled(ctx context.Context) ([]*ChannelMonitor, error)
+ MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error
+ InsertHistoryBatch(ctx context.Context, rows []*ChannelMonitorHistoryRow) error
+ DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error)
+
+ // 历史记录
+ ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error)
+
+ // 用户视图聚合
+ ListLatestPerModel(ctx context.Context, monitorID int64) ([]*ChannelMonitorLatest, error)
+ ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*ChannelMonitorAvailability, error)
+
+ // 批量聚合(admin/user list 用,避免 N+1)
+ ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*ChannelMonitorLatest, error)
+ ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*ChannelMonitorAvailability, error)
+ // ListRecentHistoryForMonitors 批量取多个 monitor 各自主模型(primaryModels[monitorID])最近 perMonitorLimit 条历史。
+ // 返回的 entry 已按 checked_at DESC 排序(最新在前),不含 message 字段。
+ ListRecentHistoryForMonitors(ctx context.Context, ids []int64, primaryModels map[int64]string, perMonitorLimit int) (map[int64][]*ChannelMonitorHistoryEntry, error)
+
+ // ---------- 聚合维护(OpsCleanupService 调用) ----------
+
+ // UpsertDailyRollupsFor 把 targetDate 当天的明细按 (monitor_id, model, bucket_date)
+ // 聚合到 channel_monitor_daily_rollups。targetDate 会被截断到日期;
+ // 用 ON CONFLICT DO UPDATE 实现幂等回填,返回 upsert 影响的行数。
+ UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error)
+ // DeleteRollupsBefore 软删 bucket_date < beforeDate 的聚合行,返回删除行数。
+ DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error)
+ // LoadAggregationWatermark 读 watermark(id=1)。
+ // 返回 nil 表示从未聚合过;watermark 表本身预期已存在单行(migration 110 写入)。
+ LoadAggregationWatermark(ctx context.Context) (*time.Time, error)
+ // UpdateAggregationWatermark 写 watermark(UPSERT 到 id=1)。
+ UpdateAggregationWatermark(ctx context.Context, date time.Time) error
+}
+
+// ChannelMonitorService 渠道监控管理服务。
+type ChannelMonitorService struct {
+ repo ChannelMonitorRepository
+ encryptor SecretEncryptor
+ // scheduler 由 wire 通过 SetScheduler 注入;CRUD 后调用对应钩子即时同步任务。
+ // 测试或未注入场景下保持 nil,所有钩子调用变为 no-op。
+ scheduler MonitorScheduler
+}
+
+// NewChannelMonitorService 创建渠道监控服务实例。
+func NewChannelMonitorService(repo ChannelMonitorRepository, encryptor SecretEncryptor) *ChannelMonitorService {
+ return &ChannelMonitorService{repo: repo, encryptor: encryptor}
+}
+
+// ---------- CRUD ----------
+
+// List 列表查询(支持 provider/enabled/search 过滤 + 分页)。
+// 返回的 ChannelMonitor.APIKey 已解密为明文,handler 层负责脱敏。
+func (s *ChannelMonitorService) List(ctx context.Context, params ChannelMonitorListParams) ([]*ChannelMonitor, int64, error) {
+ if params.Page < 1 {
+ params.Page = 1
+ }
+ if params.PageSize < 1 || params.PageSize > 200 {
+ params.PageSize = 20
+ }
+ items, total, err := s.repo.List(ctx, params)
+ if err != nil {
+ return nil, 0, fmt.Errorf("list channel monitors: %w", err)
+ }
+ for _, it := range items {
+ s.decryptInPlace(it)
+ }
+ return items, total, nil
+}
+
+// Get 查询单个监控(解密 API Key)。
+func (s *ChannelMonitorService) Get(ctx context.Context, id int64) (*ChannelMonitor, error) {
+ m, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ s.decryptInPlace(m)
+ return m, nil
+}
+
+// Create 创建监控(内部加密 api_key)。
+func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCreateParams) (*ChannelMonitor, error) {
+ if err := validateCreateParams(p); err != nil {
+ return nil, err
+ }
+ if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
+ return nil, err
+ }
+ if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
+ return nil, err
+ }
+ encrypted, err := s.encryptor.Encrypt(p.APIKey)
+ if err != nil {
+ return nil, fmt.Errorf("encrypt api key: %w", err)
+ }
+ m := &ChannelMonitor{
+ Name: strings.TrimSpace(p.Name),
+ Provider: p.Provider,
+ Endpoint: normalizeEndpoint(p.Endpoint),
+ APIKey: encrypted, // 注意:传入 repository 时该字段为密文
+ PrimaryModel: strings.TrimSpace(p.PrimaryModel),
+ ExtraModels: normalizeModels(p.ExtraModels),
+ GroupName: strings.TrimSpace(p.GroupName),
+ Enabled: p.Enabled,
+ IntervalSeconds: p.IntervalSeconds,
+ CreatedBy: p.CreatedBy,
+ TemplateID: p.TemplateID,
+ ExtraHeaders: emptyHeadersIfNil(p.ExtraHeaders),
+ BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
+ BodyOverride: p.BodyOverride,
+ }
+ if err := s.repo.Create(ctx, m); err != nil {
+ return nil, fmt.Errorf("create channel monitor: %w", err)
+ }
+ // 不再调 s.Get 重走解密链:已知刚加密的明文,直接构造响应。
+ // 这样可避免 SecretEncryptor 解密失败时 APIKey 被静默清空的问题(见 Fix 4)。
+ m.APIKey = strings.TrimSpace(p.APIKey)
+ if s.scheduler != nil {
+ s.scheduler.Schedule(m)
+ }
+ return m, nil
+}
+
+// validateCreateParams 把 Create 入参的所有校验聚拢为一个函数,避免 Create 主体超过 30 行。
+func validateCreateParams(p ChannelMonitorCreateParams) error {
+ if err := validateProvider(p.Provider); err != nil {
+ return err
+ }
+ if err := validateInterval(p.IntervalSeconds); err != nil {
+ return err
+ }
+ if err := validateEndpoint(p.Endpoint); err != nil {
+ return err
+ }
+ if strings.TrimSpace(p.APIKey) == "" {
+ return ErrChannelMonitorMissingAPIKey
+ }
+ if strings.TrimSpace(p.PrimaryModel) == "" {
+ return ErrChannelMonitorMissingPrimaryModel
+ }
+ return nil
+}
+
+// Update 更新监控。APIKey 字段:nil 或空字符串 = 不修改;非空 = 加密后覆盖。
+func (s *ChannelMonitorService) Update(ctx context.Context, id int64, p ChannelMonitorUpdateParams) (*ChannelMonitor, error) {
+ existing, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ if err := applyMonitorUpdate(existing, p); err != nil {
+ return nil, err
+ }
+
+ newPlainAPIKey, apiKeyUpdated, err := s.applyAPIKeyUpdate(existing, p.APIKey)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := s.repo.Update(ctx, existing); err != nil {
+ return nil, fmt.Errorf("update channel monitor: %w", err)
+ }
+
+ // 不再调 s.Get 重走解密链:避免二次解密带来的"密文被静默清空"风险(与 Create 一致)。
+ if apiKeyUpdated {
+ existing.APIKey = newPlainAPIKey
+ } else {
+ s.decryptInPlace(existing)
+ }
+ if s.scheduler != nil {
+ // Schedule 内部根据 Enabled 自动选择 Unschedule 或重建任务,
+ // IntervalSeconds 变化也会被自然吸收(旧 task 取消 + 新 task 用新 interval)。
+ s.scheduler.Schedule(existing)
+ }
+ return existing, nil
+}
+
+// applyAPIKeyUpdate 处理 Update 中的 APIKey 字段:
+// - 入参 raw 为 nil 或空白:不修改 existing.APIKey(仍为密文),返回 updated=false
+// - 非空:加密后写入 existing.APIKey;同时把明文返回给调用方,
+// 供写库成功后塞回 existing 避免把密文吐回客户端
+func (s *ChannelMonitorService) applyAPIKeyUpdate(existing *ChannelMonitor, raw *string) (plain string, updated bool, err error) {
+ if raw == nil || strings.TrimSpace(*raw) == "" {
+ return "", false, nil
+ }
+ plain = strings.TrimSpace(*raw)
+ encrypted, encErr := s.encryptor.Encrypt(plain)
+ if encErr != nil {
+ return "", false, fmt.Errorf("encrypt api key: %w", encErr)
+ }
+ existing.APIKey = encrypted
+ return plain, true, nil
+}
+
+// Delete 删除监控(历史通过外键 CASCADE 自动清理)。
+func (s *ChannelMonitorService) Delete(ctx context.Context, id int64) error {
+ if err := s.repo.Delete(ctx, id); err != nil {
+ return fmt.Errorf("delete channel monitor: %w", err)
+ }
+ if s.scheduler != nil {
+ s.scheduler.Unschedule(id)
+ }
+ return nil
+}
+
+// ListHistory 列出某个监控最近的检测历史。
+// model 为空表示返回所有模型;limit <= 0 时使用默认值,超过上限会被截断。
+func (s *ChannelMonitorService) ListHistory(ctx context.Context, id int64, model string, limit int) ([]*ChannelMonitorHistoryEntry, error) {
+ if _, err := s.repo.GetByID(ctx, id); err != nil {
+ return nil, err
+ }
+ if limit <= 0 {
+ limit = MonitorHistoryDefaultLimit
+ }
+ if limit > MonitorHistoryMaxLimit {
+ limit = MonitorHistoryMaxLimit
+ }
+ entries, err := s.repo.ListHistory(ctx, id, strings.TrimSpace(model), limit)
+ if err != nil {
+ return nil, fmt.Errorf("list history: %w", err)
+ }
+ return entries, nil
+}
+
+// ---------- 业务 ----------
+
+// RunCheck 同步触发对一个监控的检测:并发跑 primary + extra 模型,
+// 写历史记录并更新 last_checked_at。返回每个模型的检测结果。
+func (s *ChannelMonitorService) RunCheck(ctx context.Context, id int64) ([]*CheckResult, error) {
+ m, err := s.Get(ctx, id) // 已解密 APIKey
+ if err != nil {
+ return nil, err
+ }
+ if m.APIKeyDecryptFailed {
+ return nil, ErrChannelMonitorAPIKeyDecryptFailed
+ }
+ results := s.runChecksConcurrent(ctx, m)
+ s.persistCheckResults(ctx, m, results)
+ return results, nil
+}
+
+// persistCheckResults 写入本次检测的历史记录并更新 last_checked_at。
+// 任一写库失败都只记日志,不影响调用方拿到 results(与 MVP 期望一致:宁可漏记历史也要先返回结果)。
+func (s *ChannelMonitorService) persistCheckResults(ctx context.Context, m *ChannelMonitor, results []*CheckResult) {
+ rows := make([]*ChannelMonitorHistoryRow, 0, len(results))
+ for _, r := range results {
+ rows = append(rows, &ChannelMonitorHistoryRow{
+ MonitorID: m.ID,
+ Model: r.Model,
+ Status: r.Status,
+ LatencyMs: r.LatencyMs,
+ PingLatencyMs: r.PingLatencyMs,
+ Message: r.Message,
+ CheckedAt: r.CheckedAt,
+ })
+ }
+ if err := s.repo.InsertHistoryBatch(ctx, rows); err != nil {
+ slog.Error("channel_monitor: insert history failed",
+ "monitor_id", m.ID, "name", m.Name, "error", err)
+ }
+ if err := s.repo.MarkChecked(ctx, m.ID, time.Now()); err != nil {
+ slog.Error("channel_monitor: mark checked failed",
+ "monitor_id", m.ID, "error", err)
+ }
+}
+
+// runChecksConcurrent 对 primary + extra 模型并发执行检测。
+// errgroup 仅用于等待,不传播错误(每个 model 失败都已打包进 CheckResult)。
+func (s *ChannelMonitorService) runChecksConcurrent(ctx context.Context, m *ChannelMonitor) []*CheckResult {
+ models := append([]string{m.PrimaryModel}, m.ExtraModels...)
+ results := make([]*CheckResult, len(models))
+
+ // ping 共享一次,所有模型记录同一个 ping 延迟。
+ pingMs := pingEndpointOrigin(ctx, m.Endpoint)
+
+ // 所有模型共用同一份 CheckOptions(来自监控的快照字段)。
+ opts := &CheckOptions{
+ ExtraHeaders: m.ExtraHeaders,
+ BodyOverrideMode: m.BodyOverrideMode,
+ BodyOverride: m.BodyOverride,
+ }
+
+ var eg errgroup.Group
+ var mu sync.Mutex
+ for i, model := range models {
+ i, model := i, model
+ eg.Go(func() error {
+ r := runCheckForModel(ctx, m.Provider, m.Endpoint, m.APIKey, model, opts)
+ r.PingLatencyMs = pingMs
+ mu.Lock()
+ results[i] = r
+ mu.Unlock()
+ return nil
+ })
+ }
+ _ = eg.Wait()
+ return results
+}
+
+// ---------- 调度器协作 ----------
+
+// SetScheduler 由 wire 在 runner 构造后注入,用于在 CRUD 时即时同步任务表。
+// 通过 setter 注入避免 service ↔ runner 的依赖环。
+func (s *ChannelMonitorService) SetScheduler(sched MonitorScheduler) {
+ s.scheduler = sched
+}
+
+// ListEnabledMonitors 返回所有 enabled=true 的监控(解密后),供 runner 启动时建立任务表。
+func (s *ChannelMonitorService) ListEnabledMonitors(ctx context.Context) ([]*ChannelMonitor, error) {
+ all, err := s.repo.ListEnabled(ctx)
+ if err != nil {
+ return nil, err
+ }
+ for _, m := range all {
+ s.decryptInPlace(m)
+ }
+ return all, nil
+}
+
+// cleanupOldHistory 删除 monitorHistoryRetentionDays 天之前的明细历史记录。
+// 由 RunDailyMaintenance 调用;SoftDeleteMixin 自动把 DELETE 改为 UPDATE deleted_at。
+func (s *ChannelMonitorService) cleanupOldHistory(ctx context.Context) error {
+ before := time.Now().UTC().AddDate(0, 0, -monitorHistoryRetentionDays)
+ deleted, err := s.repo.DeleteHistoryBefore(ctx, before)
+ if err != nil {
+ return fmt.Errorf("delete history before %s: %w", before.Format(time.RFC3339), err)
+ }
+ if deleted > 0 {
+ slog.Info("channel_monitor: history cleanup",
+ "deleted_rows", deleted, "before", before.Format(time.RFC3339))
+ }
+ return nil
+}
+
+// RunDailyMaintenance 每日维护任务:聚合昨天之前未聚合的明细,软删过期明细和聚合。
+// 由 OpsCleanupService 的 cron 调度触发(共享 schedule 和 leader lock)。
+//
+// 幂等性:
+// - watermark 保证已聚合的日期不会重复处理;
+// - UpsertDailyRollupsFor 内部使用 ON CONFLICT DO UPDATE,同一日重复跑结果一致。
+//
+// 每一步失败都只记 slog.Warn,整体函数始终返回 nil 让后续步骤能继续跑
+// (与 OpsCleanupService.runCleanupOnce 风格一致)。
+func (s *ChannelMonitorService) RunDailyMaintenance(ctx context.Context) error {
+ now := time.Now().UTC()
+ today := now.Truncate(24 * time.Hour)
+
+ if err := s.runDailyAggregation(ctx, today); err != nil {
+ slog.Warn("channel_monitor: maintenance step failed",
+ "step", "aggregate", "error", err)
+ }
+ if err := s.cleanupOldHistory(ctx); err != nil {
+ slog.Warn("channel_monitor: maintenance step failed",
+ "step", "prune_history", "error", err)
+ }
+ if err := s.cleanupOldRollups(ctx, today); err != nil {
+ slog.Warn("channel_monitor: maintenance step failed",
+ "step", "prune_rollups", "error", err)
+ }
+ return nil
+}
+
+// runDailyAggregation 从 watermark+1 聚合到昨天(UTC)。
+// 首次跑(watermark nil):从 today-monitorRollupRetentionDays 开始回填。
+// 每次最多聚合 monitorMaintenanceMaxDaysPerRun 天,避免长事务。
+func (s *ChannelMonitorService) runDailyAggregation(ctx context.Context, today time.Time) error {
+ watermark, err := s.repo.LoadAggregationWatermark(ctx)
+ if err != nil {
+ return fmt.Errorf("load watermark: %w", err)
+ }
+
+ start := s.resolveAggregationStart(watermark, today)
+ if !start.Before(today) {
+ return nil // 没有需要聚合的日期
+ }
+
+ iterations := 0
+ for d := start; d.Before(today); d = d.Add(24 * time.Hour) {
+ if iterations >= monitorMaintenanceMaxDaysPerRun {
+ slog.Info("channel_monitor: maintenance aggregation capped",
+ "max_days", monitorMaintenanceMaxDaysPerRun,
+ "next_resume", d.Format("2006-01-02"))
+ break
+ }
+ affected, upErr := s.repo.UpsertDailyRollupsFor(ctx, d)
+ if upErr != nil {
+ return fmt.Errorf("upsert rollups for %s: %w", d.Format("2006-01-02"), upErr)
+ }
+ if err := s.repo.UpdateAggregationWatermark(ctx, d); err != nil {
+ return fmt.Errorf("update watermark to %s: %w", d.Format("2006-01-02"), err)
+ }
+ slog.Info("channel_monitor: rollups upserted",
+ "date", d.Format("2006-01-02"), "affected_rows", affected)
+ iterations++
+ }
+ return nil
+}
+
+// resolveAggregationStart 计算本次聚合起点:
+// - watermark == nil:today - monitorRollupRetentionDays(首次回填最多 30 天)
+// - watermark != nil:*watermark + 1 day
+func (s *ChannelMonitorService) resolveAggregationStart(watermark *time.Time, today time.Time) time.Time {
+ if watermark == nil {
+ return today.AddDate(0, 0, -monitorRollupRetentionDays)
+ }
+ return watermark.UTC().Truncate(24 * time.Hour).Add(24 * time.Hour)
+}
+
+// cleanupOldRollups 软删 bucket_date < today - monitorRollupRetentionDays 的日聚合行。
+func (s *ChannelMonitorService) cleanupOldRollups(ctx context.Context, today time.Time) error {
+ cutoff := today.AddDate(0, 0, -monitorRollupRetentionDays)
+ deleted, err := s.repo.DeleteRollupsBefore(ctx, cutoff)
+ if err != nil {
+ return fmt.Errorf("delete rollups before %s: %w", cutoff.Format("2006-01-02"), err)
+ }
+ if deleted > 0 {
+ slog.Info("channel_monitor: rollups cleanup",
+ "deleted_rows", deleted, "before", cutoff.Format("2006-01-02"))
+ }
+ return nil
+}
+
+// ---------- helpers ----------
+
+// decryptInPlace 把 ChannelMonitor.APIKey 从密文解密为明文。
+// 解密失败时把字段清空 + 设置 APIKeyDecryptFailed=true(不返回错误,避免阻断列表渲染)。
+// runner / RunCheck 必须读取该标志位并拒绝执行检测。
+func (s *ChannelMonitorService) decryptInPlace(m *ChannelMonitor) {
+ if m == nil || m.APIKey == "" {
+ return
+ }
+ plain, err := s.encryptor.Decrypt(m.APIKey)
+ if err != nil {
+ slog.Warn("channel_monitor: decrypt api key failed",
+ "monitor_id", m.ID, "error", err)
+ m.APIKey = ""
+ m.APIKeyDecryptFailed = true
+ return
+ }
+ m.APIKey = plain
+}
+
+// applyMonitorUpdate 把 update params 中非 nil 的字段应用到 existing 上。
+// APIKey 字段在调用方单独处理(涉及加密)。
+//
+// 行数稍超过 30:这是逐字段平铺的 dispatcher,每个 if 都是 1-3 行的"非 nil 则覆盖"模式,
+// 拆分反而会增加跳转噪音、影响可读性,故保留为单函数。
+func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
+ if p.Name != nil {
+ existing.Name = strings.TrimSpace(*p.Name)
+ }
+ if p.Provider != nil {
+ if err := validateProvider(*p.Provider); err != nil {
+ return err
+ }
+ existing.Provider = *p.Provider
+ }
+ if p.Endpoint != nil {
+ if err := validateEndpoint(*p.Endpoint); err != nil {
+ return err
+ }
+ existing.Endpoint = normalizeEndpoint(*p.Endpoint)
+ }
+ if p.PrimaryModel != nil {
+ existing.PrimaryModel = strings.TrimSpace(*p.PrimaryModel)
+ }
+ if p.ExtraModels != nil {
+ existing.ExtraModels = normalizeModels(*p.ExtraModels)
+ }
+ if p.GroupName != nil {
+ existing.GroupName = strings.TrimSpace(*p.GroupName)
+ }
+ if p.Enabled != nil {
+ existing.Enabled = *p.Enabled
+ }
+ if p.IntervalSeconds != nil {
+ if err := validateInterval(*p.IntervalSeconds); err != nil {
+ return err
+ }
+ existing.IntervalSeconds = *p.IntervalSeconds
+ }
+ return applyMonitorAdvancedUpdate(existing, p)
+}
+
+// applyMonitorAdvancedUpdate 处理自定义请求快照相关字段,从 applyMonitorUpdate 拆出避免过长。
+func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
+ if p.ClearTemplate {
+ existing.TemplateID = nil
+ } else if p.TemplateID != nil {
+ id := *p.TemplateID
+ existing.TemplateID = &id
+ }
+ if p.ExtraHeaders != nil {
+ if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
+ return err
+ }
+ existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
+ }
+ // BodyOverrideMode / BodyOverride 联合校验,和模板一致。
+ newMode := existing.BodyOverrideMode
+ newBody := existing.BodyOverride
+ if p.BodyOverrideMode != nil {
+ newMode = *p.BodyOverrideMode
+ }
+ if p.BodyOverride != nil {
+ newBody = *p.BodyOverride
+ }
+ if p.BodyOverrideMode != nil || p.BodyOverride != nil {
+ if err := validateBodyModeParams(newMode, newBody); err != nil {
+ return err
+ }
+ existing.BodyOverrideMode = defaultBodyMode(newMode)
+ existing.BodyOverride = newBody
+ }
+ return nil
+}
diff --git a/backend/internal/service/channel_monitor_ssrf.go b/backend/internal/service/channel_monitor_ssrf.go
new file mode 100644
index 00000000..8d93f600
--- /dev/null
+++ b/backend/internal/service/channel_monitor_ssrf.go
@@ -0,0 +1,152 @@
+package service
+
+import (
+ "context"
+ "net"
+ "strings"
+)
+
+// SSRF 防护 helper:
+// - validateEndpoint 在 admin 提交时阻止 http/loopback/私网/云元数据 URL
+// - safeDialContext 在 socket 层再次校验真实 IP,防止 DNS rebinding
+//
+// 已知 cloud metadata hostname 拒绝列表(小写比较)。
+var monitorBlockedHostnames = map[string]struct{}{
+ "localhost": {},
+ "localhost.localdomain": {},
+ "metadata": {},
+ "metadata.google.internal": {},
+ "metadata.goog": {},
+ "instance-data": {},
+ "instance-data.ec2.internal": {},
+}
+
+// CIDR 列表:包含所有需要拒绝的 IPv4/IPv6 段。
+// 解析时只 panic 一次(启动时确认),生产路径只做 Contains。
+var monitorBlockedCIDRs = mustParseCIDRs([]string{
+ "127.0.0.0/8", // IPv4 loopback
+ "10.0.0.0/8", // RFC1918
+ "172.16.0.0/12", // RFC1918
+ "192.168.0.0/16", // RFC1918
+ "169.254.0.0/16", // link-local(含云元数据 169.254.169.254)
+ "100.64.0.0/10", // CGNAT
+ "0.0.0.0/8", // "this network"
+ "::1/128", // IPv6 loopback
+ "fc00::/7", // IPv6 ULA
+ "fe80::/10", // IPv6 link-local
+ "::/128", // IPv6 unspecified
+})
+
+// monitorDialer 共享 Dialer,与 net/http 默认值对齐。
+var monitorDialer = &net.Dialer{
+ Timeout: monitorDialTimeout,
+ KeepAlive: monitorDialKeepAlive,
+}
+
+// mustParseCIDRs 在包初始化时解析 CIDR 字符串,失败 panic。
+func mustParseCIDRs(cidrs []string) []*net.IPNet {
+ out := make([]*net.IPNet, 0, len(cidrs))
+ for _, c := range cidrs {
+ _, n, err := net.ParseCIDR(c)
+ if err != nil {
+ panic("channel_monitor_ssrf: invalid CIDR " + c + ": " + err.Error())
+ }
+ out = append(out, n)
+ }
+ return out
+}
+
+// isBlockedHostname 判断 hostname 是否命中黑名单。
+func isBlockedHostname(hostname string) bool {
+ if hostname == "" {
+ return true
+ }
+ _, blocked := monitorBlockedHostnames[strings.ToLower(hostname)]
+ return blocked
+}
+
+// isPrivateIP 判断 IP 是否落在禁止段(loopback/RFC1918/link-local/ULA 等)。
+func isPrivateIP(ip net.IP) bool {
+ if ip == nil {
+ return true
+ }
+ if ip.IsUnspecified() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
+ return true
+ }
+ for _, n := range monitorBlockedCIDRs {
+ if n.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+// isPrivateOrLoopbackHost 解析 hostname 的所有 A/AAAA 记录,
+// 任一 IP 落在私网/loopback 段即认为不安全。
+//
+// hostname 是 IP 字面量时也走同一路径。
+func isPrivateOrLoopbackHost(ctx context.Context, hostname string) (bool, error) {
+ if isBlockedHostname(hostname) {
+ return true, nil
+ }
+ // IP 字面量直接判断。
+ if ip := net.ParseIP(hostname); ip != nil {
+ return isPrivateIP(ip), nil
+ }
+ resolver := net.DefaultResolver
+ addrs, err := resolver.LookupIPAddr(ctx, hostname)
+ if err != nil {
+ return false, err
+ }
+ if len(addrs) == 0 {
+ return true, nil
+ }
+ for _, a := range addrs {
+ if isPrivateIP(a.IP) {
+ return true, nil
+ }
+ }
+ return false, nil
+}
+
+// safeDialContext 在真实 dial 前再次校验目标 IP,防止 DNS rebinding。
+// 解析 hostname 后逐个 IP 尝试连接,命中私网即拒绝(即便 validateEndpoint 时返回的是公网 IP)。
+func safeDialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+ // 字面量 IP 走快速路径。
+ if ip := net.ParseIP(host); ip != nil {
+ if isPrivateIP(ip) {
+ return nil, &net.AddrError{Err: "blocked by SSRF policy", Addr: address}
+ }
+ return monitorDialer.DialContext(ctx, network, address)
+ }
+ if isBlockedHostname(host) {
+ return nil, &net.AddrError{Err: "blocked by SSRF policy", Addr: address}
+ }
+ addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
+ if err != nil {
+ return nil, err
+ }
+ if len(addrs) == 0 {
+ return nil, &net.AddrError{Err: "no addresses for host", Addr: host}
+ }
+ var lastErr error
+ for _, a := range addrs {
+ if isPrivateIP(a.IP) {
+ lastErr = &net.AddrError{Err: "blocked by SSRF policy", Addr: a.IP.String()}
+ continue
+ }
+ conn, err := monitorDialer.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port))
+ if err == nil {
+ return conn, nil
+ }
+ lastErr = err
+ }
+ if lastErr == nil {
+ lastErr = &net.AddrError{Err: "no usable addresses", Addr: host}
+ }
+ return nil, lastErr
+}
diff --git a/backend/internal/service/channel_monitor_template_service.go b/backend/internal/service/channel_monitor_template_service.go
new file mode 100644
index 00000000..8d2e8173
--- /dev/null
+++ b/backend/internal/service/channel_monitor_template_service.go
@@ -0,0 +1,251 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strings"
+)
+
+// ChannelMonitorRequestTemplateRepository 模板数据访问接口。
+type ChannelMonitorRequestTemplateRepository interface {
+ Create(ctx context.Context, t *ChannelMonitorRequestTemplate) error
+ GetByID(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error)
+ Update(ctx context.Context, t *ChannelMonitorRequestTemplate) error
+ Delete(ctx context.Context, id int64) error
+ List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error)
+ // ApplyToMonitors 把模板当前的 extra_headers / body_override_mode / body_override
+ // 批量覆盖到指定 monitorIDs 的监控上(同时还要求这些监控当前 template_id = id,
+ // 防止误覆盖未关联的监控)。monitorIDs 必须非空;空列表直接返回 0 不写库。
+ // 返回被覆盖的监控数量。
+ ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error)
+ // CountAssociatedMonitors 统计 template_id = id 的监控数(用于 UI 展示「应用到 N 个配置」)。
+ CountAssociatedMonitors(ctx context.Context, id int64) (int64, error)
+ // ListAssociatedMonitors 列出所有 template_id = id 的监控简略信息(id/name/provider/enabled)
+ // 给 apply picker UI 用,避免前端再做一次 list+filter。
+ ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error)
+}
+
+// AssociatedMonitorBrief 模板关联监控的简略信息(picker / 列表展示用)。
+type AssociatedMonitorBrief struct {
+ ID int64
+ Name string
+ Provider string
+ Enabled bool
+}
+
+// ChannelMonitorRequestTemplateService 模板管理 service。
+type ChannelMonitorRequestTemplateService struct {
+ repo ChannelMonitorRequestTemplateRepository
+}
+
+// NewChannelMonitorRequestTemplateService 创建模板 service。
+func NewChannelMonitorRequestTemplateService(repo ChannelMonitorRequestTemplateRepository) *ChannelMonitorRequestTemplateService {
+ return &ChannelMonitorRequestTemplateService{repo: repo}
+}
+
+// ---------- CRUD ----------
+
+// List 按 provider 过滤(空串 = 全部),不分页(模板量级小)。
+func (s *ChannelMonitorRequestTemplateService) List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error) {
+ if params.Provider != "" {
+ if err := validateProvider(params.Provider); err != nil {
+ return nil, err
+ }
+ }
+ return s.repo.List(ctx, params)
+}
+
+// Get 返回单个模板。
+func (s *ChannelMonitorRequestTemplateService) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) {
+ return s.repo.GetByID(ctx, id)
+}
+
+// Create 创建模板(会校验 headers 黑名单和 body 模式匹配)。
+func (s *ChannelMonitorRequestTemplateService) Create(ctx context.Context, p ChannelMonitorRequestTemplateCreateParams) (*ChannelMonitorRequestTemplate, error) {
+ if err := validateTemplateCreateParams(p); err != nil {
+ return nil, err
+ }
+ t := &ChannelMonitorRequestTemplate{
+ Name: strings.TrimSpace(p.Name),
+ Provider: p.Provider,
+ Description: strings.TrimSpace(p.Description),
+ ExtraHeaders: emptyHeadersIfNil(p.ExtraHeaders),
+ BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
+ BodyOverride: p.BodyOverride,
+ }
+ if err := s.repo.Create(ctx, t); err != nil {
+ return nil, fmt.Errorf("create template: %w", err)
+ }
+ return t, nil
+}
+
+// Update 更新模板(provider 不可改)。
+func (s *ChannelMonitorRequestTemplateService) Update(ctx context.Context, id int64, p ChannelMonitorRequestTemplateUpdateParams) (*ChannelMonitorRequestTemplate, error) {
+ existing, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ if err := applyTemplateUpdate(existing, p); err != nil {
+ return nil, err
+ }
+ if err := s.repo.Update(ctx, existing); err != nil {
+ return nil, fmt.Errorf("update template: %w", err)
+ }
+ return existing, nil
+}
+
+// Delete 删除模板。关联监控的 template_id 会被 SET NULL,监控保留快照继续跑。
+func (s *ChannelMonitorRequestTemplateService) Delete(ctx context.Context, id int64) error {
+ if err := s.repo.Delete(ctx, id); err != nil {
+ return fmt.Errorf("delete template: %w", err)
+ }
+ return nil
+}
+
+// ApplyToMonitors 把模板当前配置应用到 monitorIDs 列表里的关联监控。
+// monitorIDs 必须非空且每个 id 都必须当前 template_id = id;不满足条件的会被 SQL WHERE 过滤掉。
+// 返回实际被覆盖的监控数。
+func (s *ChannelMonitorRequestTemplateService) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) {
+ if _, err := s.repo.GetByID(ctx, id); err != nil {
+ return 0, err
+ }
+ if len(monitorIDs) == 0 {
+ return 0, ErrChannelMonitorTemplateApplyEmpty
+ }
+ affected, err := s.repo.ApplyToMonitors(ctx, id, monitorIDs)
+ if err != nil {
+ return 0, fmt.Errorf("apply template to monitors: %w", err)
+ }
+ return affected, nil
+}
+
+// CountAssociatedMonitors 返回关联监控数。
+func (s *ChannelMonitorRequestTemplateService) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) {
+ return s.repo.CountAssociatedMonitors(ctx, id)
+}
+
+// ListAssociatedMonitors 返回模板关联的所有监控简略信息。
+// 给前端 apply picker 用,handler 直接吐 JSON 不再做 join。
+func (s *ChannelMonitorRequestTemplateService) ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error) {
+ if _, err := s.repo.GetByID(ctx, id); err != nil {
+ return nil, err
+ }
+ return s.repo.ListAssociatedMonitors(ctx, id)
+}
+
+// ---------- 校验 & 工具 ----------
+
+// validateTemplateCreateParams 聚合 create 入参校验,避免函数超过 30 行。
+func validateTemplateCreateParams(p ChannelMonitorRequestTemplateCreateParams) error {
+ if strings.TrimSpace(p.Name) == "" {
+ return ErrChannelMonitorTemplateMissingName
+ }
+ if err := validateProvider(p.Provider); err != nil {
+ return ErrChannelMonitorTemplateInvalidProvider
+ }
+ if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
+ return err
+ }
+ if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
+ return err
+ }
+ return nil
+}
+
+// applyTemplateUpdate 把 update params 中非 nil 字段应用到 existing 上。
+func applyTemplateUpdate(existing *ChannelMonitorRequestTemplate, p ChannelMonitorRequestTemplateUpdateParams) error {
+ if p.Name != nil {
+ name := strings.TrimSpace(*p.Name)
+ if name == "" {
+ return ErrChannelMonitorTemplateMissingName
+ }
+ existing.Name = name
+ }
+ if p.Description != nil {
+ existing.Description = strings.TrimSpace(*p.Description)
+ }
+ if p.ExtraHeaders != nil {
+ if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
+ return err
+ }
+ existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
+ }
+ // BodyOverrideMode / BodyOverride 联合校验:任一变化都用「更新后的值」做校验。
+ newMode := existing.BodyOverrideMode
+ newBody := existing.BodyOverride
+ if p.BodyOverrideMode != nil {
+ newMode = *p.BodyOverrideMode
+ }
+ if p.BodyOverride != nil {
+ newBody = *p.BodyOverride
+ }
+ if err := validateBodyModeParams(newMode, newBody); err != nil {
+ return err
+ }
+ existing.BodyOverrideMode = defaultBodyMode(newMode)
+ existing.BodyOverride = newBody
+ return nil
+}
+
+// validateBodyModeParams 校验 body_override_mode 合法,且 merge/replace 模式下 body_override 非空。
+func validateBodyModeParams(mode string, body map[string]any) error {
+ switch mode {
+ case "", MonitorBodyOverrideModeOff:
+ return nil
+ case MonitorBodyOverrideModeMerge, MonitorBodyOverrideModeReplace:
+ if len(body) == 0 {
+ return ErrChannelMonitorTemplateBodyRequired
+ }
+ return nil
+ default:
+ return ErrChannelMonitorTemplateInvalidBodyMode
+ }
+}
+
+// headerNameRegex 合法 header 名:RFC 7230 token(ASCII 可见字符减特殊符号)。
+var headerNameRegex = regexp.MustCompile(`^[A-Za-z0-9!#$%&'*+\-.^_` + "`" + `|~]+$`)
+
+// forbiddenHeaderNames hop-by-hop + HTTP 客户端自管的 header;禁止用户覆盖,
+// 否则会让 Go http.Client 行为异常(双重 Content-Length、连接复用错乱等)。
+var forbiddenHeaderNames = map[string]bool{
+ "host": true,
+ "content-length": true,
+ "content-encoding": true,
+ "transfer-encoding": true,
+ "connection": true,
+}
+
+// IsForbiddenHeaderName 对外暴露,checker 运行时也会再过滤一次做兜底。
+func IsForbiddenHeaderName(name string) bool {
+ return forbiddenHeaderNames[strings.ToLower(strings.TrimSpace(name))]
+}
+
+// validateExtraHeaders 校验 header 名字格式 + 黑名单。保存时就拒绝非法 header,早失败。
+func validateExtraHeaders(h map[string]string) error {
+ for k := range h {
+ if !headerNameRegex.MatchString(k) {
+ return ErrChannelMonitorTemplateHeaderInvalidName
+ }
+ if IsForbiddenHeaderName(k) {
+ return ErrChannelMonitorTemplateHeaderForbidden
+ }
+ }
+ return nil
+}
+
+// emptyHeadersIfNil 把 nil map 归一成空 map(repo 层写库时 JSONB 需要非 nil)。
+func emptyHeadersIfNil(h map[string]string) map[string]string {
+ if h == nil {
+ return map[string]string{}
+ }
+ return h
+}
+
+// defaultBodyMode 空串归一为 off。
+func defaultBodyMode(mode string) string {
+ if mode == "" {
+ return MonitorBodyOverrideModeOff
+ }
+ return mode
+}
diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go
new file mode 100644
index 00000000..e5bf7568
--- /dev/null
+++ b/backend/internal/service/channel_monitor_template_types.go
@@ -0,0 +1,77 @@
+package service
+
+import (
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "time"
+)
+
+// ChannelMonitorRequestTemplate 请求模板(service 层模型)。
+// 作用:把一组可复用的 headers + 可选 body 覆盖配置抽出来管理,
+// 被监控「应用」时以快照方式拷贝到监控本身的同名字段。
+type ChannelMonitorRequestTemplate struct {
+ ID int64
+ Name string
+ Provider string
+ Description string
+ ExtraHeaders map[string]string
+ BodyOverrideMode string
+ BodyOverride map[string]any
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+// ChannelMonitorRequestTemplateListParams 列表过滤。
+type ChannelMonitorRequestTemplateListParams struct {
+ Provider string // 空 = 全部;非空则按 provider 过滤
+}
+
+// ChannelMonitorRequestTemplateCreateParams 创建参数。
+type ChannelMonitorRequestTemplateCreateParams struct {
+ Name string
+ Provider string
+ Description string
+ ExtraHeaders map[string]string
+ BodyOverrideMode string
+ BodyOverride map[string]any
+}
+
+// ChannelMonitorRequestTemplateUpdateParams 更新参数(指针字段 = 不修改)。
+// 注意 Provider 不可修改:改 provider 会让已关联监控的 body 黑名单语义错乱。
+type ChannelMonitorRequestTemplateUpdateParams struct {
+ Name *string
+ Description *string
+ ExtraHeaders *map[string]string
+ BodyOverrideMode *string
+ BodyOverride *map[string]any
+}
+
+// 模板相关错误(命名与现有 ErrChannelMonitor* 风格保持一致)。
+var (
+ ErrChannelMonitorTemplateNotFound = infraerrors.NotFound(
+ "CHANNEL_MONITOR_TEMPLATE_NOT_FOUND", "channel monitor request template not found",
+ )
+ ErrChannelMonitorTemplateInvalidProvider = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_INVALID_PROVIDER", "template provider must be one of openai/anthropic/gemini",
+ )
+ ErrChannelMonitorTemplateMissingName = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_MISSING_NAME", "template name is required",
+ )
+ ErrChannelMonitorTemplateInvalidBodyMode = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_INVALID_BODY_MODE", "body_override_mode must be one of off/merge/replace",
+ )
+ ErrChannelMonitorTemplateBodyRequired = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_BODY_REQUIRED", "body_override is required when body_override_mode is merge or replace",
+ )
+ ErrChannelMonitorTemplateHeaderForbidden = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_HEADER_FORBIDDEN", "header name is forbidden (hop-by-hop or computed by HTTP client)",
+ )
+ ErrChannelMonitorTemplateHeaderInvalidName = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_HEADER_INVALID_NAME", "header name contains invalid characters",
+ )
+ ErrChannelMonitorTemplateProviderMismatch = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_PROVIDER_MISMATCH", "monitor provider does not match template provider",
+ )
+ ErrChannelMonitorTemplateApplyEmpty = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_APPLY_EMPTY", "monitor_ids must be a non-empty array",
+ )
+)
diff --git a/backend/internal/service/channel_monitor_types.go b/backend/internal/service/channel_monitor_types.go
new file mode 100644
index 00000000..b797a89b
--- /dev/null
+++ b/backend/internal/service/channel_monitor_types.go
@@ -0,0 +1,203 @@
+package service
+
+import "time"
+
+// MonitorBodyOverrideMode 自定义请求体处理模式。
+//
+// - off 使用 adapter 默认 body(忽略 BodyOverride)
+// - merge adapter 默认 body 与 BodyOverride 浅合并(用户优先;
+// model/messages/contents 等关键字段在 checker 黑名单内会被静默丢弃)
+// - replace 完全用 BodyOverride 作为 body;跳过 challenge 校验,
+// 改成 HTTP 2xx + 响应非空即视为可用(用户负责构造 body)
+const (
+ MonitorBodyOverrideModeOff = "off"
+ MonitorBodyOverrideModeMerge = "merge"
+ MonitorBodyOverrideModeReplace = "replace"
+)
+
+// ChannelMonitor 渠道监控配置(service 层模型,不直接暴露 ent 类型)。
+type ChannelMonitor struct {
+ ID int64
+ Name string
+ Provider string
+ Endpoint string
+ APIKey string // 解密后的明文 API Key(仅在 service 内部使用,handler 层不应直接序列化返回)
+ PrimaryModel string
+ ExtraModels []string
+ GroupName string
+ Enabled bool
+ IntervalSeconds int
+ LastCheckedAt *time.Time
+ CreatedBy int64
+ CreatedAt time.Time
+ UpdatedAt time.Time
+
+ // 请求自定义快照(来自模板拷贝 or 用户手填,运行时直接读取)
+ TemplateID *int64 // 仅用于 UI 分组 + 一键应用,运行时不用
+ ExtraHeaders map[string]string // 与 adapter 默认 headers 合并,用户优先
+ BodyOverrideMode string // off / merge / replace
+ BodyOverride map[string]any // 仅 mode != off 时使用
+
+ // APIKeyDecryptFailed 表示 APIKey 字段无法解密(密钥不一致或损坏)。
+ // 此时 APIKey 为空字符串,runner / RunCheck 必须跳过该监控并提示重填。
+ APIKeyDecryptFailed bool
+}
+
+// ChannelMonitorListParams 列表查询过滤参数。
+type ChannelMonitorListParams struct {
+ Page int
+ PageSize int
+ Provider string
+ Enabled *bool
+ Search string
+}
+
+// ChannelMonitorCreateParams 创建参数。
+type ChannelMonitorCreateParams struct {
+ Name string
+ Provider string
+ Endpoint string
+ APIKey string
+ PrimaryModel string
+ ExtraModels []string
+ GroupName string
+ Enabled bool
+ IntervalSeconds int
+ CreatedBy int64
+ TemplateID *int64
+ ExtraHeaders map[string]string
+ BodyOverrideMode string
+ BodyOverride map[string]any
+}
+
+// ChannelMonitorUpdateParams 更新参数(指针字段表示"未提供则不更新")。
+type ChannelMonitorUpdateParams struct {
+ Name *string
+ Provider *string
+ Endpoint *string
+ APIKey *string // 空字符串表示不修改;非空字符串覆盖
+ PrimaryModel *string
+ ExtraModels *[]string
+ GroupName *string
+ Enabled *bool
+ IntervalSeconds *int
+ // 自定义快照字段:指针为 nil 表示不更新,非 nil 覆盖
+ // TemplateID *(*int64):用 ** 表达三态:nil=不更新;&nil=清空;&&id=设为 id。
+ // 简化处理:用 ClearTemplate 显式标志 + TemplateID(普通指针)
+ TemplateID *int64
+ ClearTemplate bool // true 时无视 TemplateID,把监控的 template_id 置空
+ ExtraHeaders *map[string]string
+ BodyOverrideMode *string
+ BodyOverride *map[string]any
+}
+
+// CheckResult 单个模型一次检测的结果。
+type CheckResult struct {
+ Model string
+ Status string // operational / degraded / failed / error
+ LatencyMs *int
+ PingLatencyMs *int
+ Message string
+ CheckedAt time.Time
+}
+
+// UserMonitorView 用户只读视图:监控概览(含主模型最近状态 + 7d 可用率 + 附加模型最近状态)。
+type UserMonitorView struct {
+ ID int64
+ Name string
+ Provider string
+ GroupName string
+ PrimaryModel string
+ PrimaryStatus string
+ PrimaryLatencyMs *int
+ PrimaryPingLatencyMs *int // 主模型最近一次 ping 延迟
+ Availability7d float64 // 0-100
+ ExtraModels []ExtraModelStatus
+ Timeline []UserMonitorTimelinePoint // 主模型最近 N 个历史点(按 checked_at DESC,最新在前)
+}
+
+// UserMonitorTimelinePoint 用户视图 timeline 单点数据(去除 message 以减小响应体)。
+type UserMonitorTimelinePoint struct {
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+ PingLatencyMs *int `json:"ping_latency_ms"`
+ CheckedAt time.Time `json:"checked_at"`
+}
+
+// ExtraModelStatus 附加模型最近一次状态。
+type ExtraModelStatus struct {
+ Model string
+ Status string
+ LatencyMs *int
+}
+
+// UserMonitorDetail 用户只读视图:监控详情(含全部模型 7d/15d/30d 可用率与平均延迟)。
+type UserMonitorDetail struct {
+ ID int64
+ Name string
+ Provider string
+ GroupName string
+ Models []ModelDetail
+}
+
+// ModelDetail 单个模型的可用率/延迟统计。
+type ModelDetail struct {
+ Model string
+ LatestStatus string
+ LatestLatencyMs *int
+ Availability7d float64 // 0-100
+ Availability15d float64
+ Availability30d float64
+ AvgLatency7dMs *int
+}
+
+// ChannelMonitorHistoryRow 历史记录入库行(service 层向 repository 提交的数据)。
+type ChannelMonitorHistoryRow struct {
+ MonitorID int64
+ Model string
+ Status string
+ LatencyMs *int
+ PingLatencyMs *int
+ Message string
+ CheckedAt time.Time
+}
+
+// ChannelMonitorHistoryEntry 历史记录查询返回行(含 ent 主键 ID)。
+type ChannelMonitorHistoryEntry struct {
+ ID int64
+ Model string
+ Status string
+ LatencyMs *int
+ PingLatencyMs *int
+ Message string
+ CheckedAt time.Time
+}
+
+// ChannelMonitorLatest 最近一次检测的简明信息(用于 UserMonitorView 聚合)。
+type ChannelMonitorLatest struct {
+ Model string
+ Status string
+ LatencyMs *int
+ PingLatencyMs *int
+ CheckedAt time.Time
+}
+
+// ChannelMonitorAvailability 单个模型在某窗口内的可用率与平均延迟(用于 UserMonitorDetail 聚合)。
+type ChannelMonitorAvailability struct {
+ Model string
+ WindowDays int
+ TotalChecks int
+ OperationalChecks int // operational + degraded 视为可用
+ AvailabilityPct float64
+ AvgLatencyMs *int
+}
+
+// MonitorStatusSummary 监控状态聚合(admin list 用,单次 repo 查询消除前端 N+1)。
+// PrimaryStatus / PrimaryLatencyMs 描述主模型最近状态;Availability7d 是主模型 7 天可用率;
+// ExtraModels 描述附加模型最近状态(用于 hover 展示)。
+type MonitorStatusSummary struct {
+ PrimaryStatus string // 空字符串表示无历史
+ PrimaryLatencyMs *int
+ Availability7d float64 // 0-100,无历史时为 0
+ ExtraModels []ExtraModelStatus
+}
diff --git a/backend/internal/service/channel_monitor_validate.go b/backend/internal/service/channel_monitor_validate.go
new file mode 100644
index 00000000..16bbec71
--- /dev/null
+++ b/backend/internal/service/channel_monitor_validate.go
@@ -0,0 +1,99 @@
+package service
+
+import (
+ "context"
+ "net/url"
+ "strings"
+)
+
+// 渠道监控参数校验与归一化辅助函数。
+// 校验失败一律返回 channel_monitor_const.go 中预定义的 Err* 错误,错误信息不含具体 IP/hostname,避免泄露内网拓扑。
+
+// validateProvider 校验 provider 字符串。
+// 唯一来源于 providerAdapters:新增 provider 只需要在 channel_monitor_checker.go 注册 adapter。
+func validateProvider(p string) error {
+ if !isSupportedProvider(p) {
+ return ErrChannelMonitorInvalidProvider
+ }
+ return nil
+}
+
+// validateInterval 校验 interval_seconds 范围。
+func validateInterval(sec int) error {
+ if sec < monitorMinIntervalSeconds || sec > monitorMaxIntervalSeconds {
+ return ErrChannelMonitorInvalidInterval
+ }
+ return nil
+}
+
+// validateEndpoint 校验 endpoint:
+// - scheme 强制 https(拒绝 http,避免明文凭证 + 部分 SSRF 利用面)
+// - 必须为 origin(无 path/query/fragment),防止用户填 https://api.openai.com/v1
+// 导致 joinURL 拼出 /v1/v1/chat/completions
+// - hostname 不能是 localhost/metadata 等已知元数据 hostname
+// - 解析所有 IP,任一落在 loopback/RFC1918/link-local/ULA 段即拒绝(防 SSRF)
+//
+// 错误信息不暴露具体 IP / hostname,避免泄露内网拓扑。
+func validateEndpoint(ep string) error {
+ ep = strings.TrimSpace(ep)
+ if ep == "" {
+ return ErrChannelMonitorInvalidEndpoint
+ }
+ u, err := url.Parse(ep)
+ if err != nil {
+ return ErrChannelMonitorInvalidEndpoint
+ }
+ if u.Scheme != "https" {
+ return ErrChannelMonitorEndpointScheme
+ }
+ if u.Host == "" {
+ return ErrChannelMonitorInvalidEndpoint
+ }
+ if u.Path != "" && u.Path != "/" {
+ return ErrChannelMonitorEndpointPath
+ }
+ if u.RawQuery != "" || u.Fragment != "" {
+ return ErrChannelMonitorEndpointPath
+ }
+
+ hostname := u.Hostname()
+ ctx, cancel := context.WithTimeout(context.Background(), monitorEndpointResolveTimeout)
+ defer cancel()
+ blocked, err := isPrivateOrLoopbackHost(ctx, hostname)
+ if err != nil {
+ return ErrChannelMonitorEndpointUnreachable
+ }
+ if blocked {
+ return ErrChannelMonitorEndpointPrivate
+ }
+ return nil
+}
+
+// normalizeEndpoint 去除前后空白与末尾 `/`,保证存储统一为 origin。
+// validateEndpoint 已确保格式合法(仅 origin),这里只做最终归一化。
+func normalizeEndpoint(ep string) string {
+ ep = strings.TrimSpace(ep)
+ ep = strings.TrimRight(ep, "/")
+ return ep
+}
+
+// normalizeModels 去除空白、重复模型名。保留输入顺序(map 的迭代顺序无关)。
+func normalizeModels(in []string) []string {
+ if len(in) == 0 {
+ return []string{}
+ }
+ seen := make(map[string]struct{}, len(in))
+ out := make([]string, 0, len(in))
+ for _, m := range in {
+ m = strings.TrimSpace(m)
+ if m == "" {
+ continue
+ }
+ if _, ok := seen[m]; ok {
+ continue
+ }
+ seen[m] = struct{}{}
+ out = append(out, m)
+ }
+ return out
+}
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index c29550d9..4e08df4a 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -141,17 +141,23 @@ const (
// ChannelService 渠道管理服务
type ChannelService struct {
repo ChannelRepository
+ groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
+ pricingService *PricingService // 用于「可用渠道」展示时回落到全局定价;可为 nil(测试场景)
cache atomic.Value // *channelCache
cacheSF singleflight.Group
}
-// NewChannelService 创建渠道服务实例
-func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService {
+// NewChannelService 创建渠道服务实例。
+// pricingService 仅供 ListAvailable 在渠道未配置定价时回落到全局 LiteLLM 数据;
+// 计费热路径走独立的 ModelPricingResolver,与此参数无关。可传 nil。
+func NewChannelService(repo ChannelRepository, groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, pricingService *PricingService) *ChannelService {
s := &ChannelService{
repo: repo,
+ groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
+ pricingService: pricingService,
}
return s
}
@@ -299,6 +305,9 @@ func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[i
}
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
+// 装填时对每个 Channel 统一归一化 BillingModelSource,让缓存命中的所有下游
+// (gateway routing / billing / 未来任何 cache-backed 读路径)都拿到已归一化的实体,
+// 避免"每个出口各自记得 normalize"反模式。
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
@@ -306,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
cache.loadedAt = time.Now()
for i := range channels {
+ channels[i].normalizeBillingModelSource()
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
@@ -516,14 +526,13 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
// resolveMapping 基于已查找的渠道信息解析模型映射。
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
+ // lk.channel 来自已装填的缓存,BillingModelSource 已在 populateChannelCache 阶段归一化,
+ // 这里无需重复兜底。
result := ChannelMappingResult{
MappedModel: model,
ChannelID: lk.channel.ID,
BillingModelSource: lk.channel.BillingModelSource,
}
- if result.BillingModelSource == "" {
- result.BillingModelSource = BillingModelSourceChannelMapped
- }
modelLower := strings.ToLower(model)
if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
@@ -684,9 +693,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
AccountStatsPricingRules: input.AccountStatsPricingRules,
}
- if channel.BillingModelSource == "" {
- channel.BillingModelSource = BillingModelSourceChannelMapped
- }
+ channel.normalizeBillingModelSource()
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
@@ -702,12 +709,23 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
}
s.invalidateCache()
- return s.repo.GetByID(ctx, channel.ID)
+ created, err := s.repo.GetByID(ctx, channel.ID)
+ if err != nil {
+ return nil, err
+ }
+ created.normalizeBillingModelSource()
+ return created, nil
}
-// GetByID 获取渠道详情
+// GetByID 获取渠道详情。返回前统一把空 BillingModelSource 回填为 ChannelMapped,
+// 让所有 handler 无需重复处理历史空值。
func (s *ChannelService) GetByID(ctx context.Context, id int64) (*Channel, error) {
- return s.repo.GetByID(ctx, id)
+ ch, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ ch.normalizeBillingModelSource()
+ return ch, nil
}
// Update 更新渠道
@@ -739,7 +757,12 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
s.invalidateCache()
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
- return s.repo.GetByID(ctx, id)
+ updated, err := s.repo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ updated.normalizeBillingModelSource()
+ return updated, nil
}
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
@@ -857,7 +880,14 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// List 获取渠道列表
func (s *ChannelService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) {
- return s.repo.List(ctx, params, status, search)
+ channels, res, err := s.repo.List(ctx, params, status, search)
+ if err != nil {
+ return nil, nil, err
+ }
+ for i := range channels {
+ channels[i].normalizeBillingModelSource()
+ }
+ return channels, res, nil
}
// modelEntry 表示一个模型模式条目(用于冲突检测)
@@ -884,12 +914,7 @@ func conflictsBetween(a, b modelEntry) bool {
// toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry {
- lower := strings.ToLower(pattern)
- isWild := strings.HasSuffix(lower, "*")
- prefix := lower
- if isWild {
- prefix = strings.TrimSuffix(lower, "*")
- }
+ prefix, isWild := splitWildcardSuffix(strings.ToLower(pattern))
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
}
diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go
index e1345618..e737a211 100644
--- a/backend/internal/service/channel_service_test.go
+++ b/backend/internal/service/channel_service_test.go
@@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context
// ---------------------------------------------------------------------------
func newTestChannelService(repo *mockChannelRepository) *ChannelService {
- return NewChannelService(repo, nil)
+ return NewChannelService(repo, nil, nil, nil)
}
func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService {
- return NewChannelService(repo, auth)
+ return NewChannelService(repo, nil, auth, nil)
}
// makeStandardRepo returns a repo that serves one active channel with anthropic pricing
diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go
index deac64d6..164861fb 100644
--- a/backend/internal/service/channel_test.go
+++ b/backend/internal/service/channel_test.go
@@ -433,3 +433,296 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last")
}
+
+func TestSupportedModels_ExactKeysAndPricing(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 10, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
+ {ID: 11, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "claude-sonnet-4-6": "claude-sonnet-4-6",
+ "claude-opus-4-6": "claude-opus-4-6",
+ },
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ require.Equal(t, "anthropic", got[0].Platform)
+ require.Equal(t, "claude-opus-4-6", got[0].Name)
+ require.NotNil(t, got[0].Pricing)
+ require.Equal(t, int64(11), got[0].Pricing.ID)
+ require.Equal(t, "claude-sonnet-4-6", got[1].Name)
+ require.Equal(t, int64(10), got[1].Pricing.ID)
+}
+
+func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
+ {ID: 2, Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "claude-sonnet-*": "claude-sonnet-4-6",
+ },
+ },
+ }
+
+ got := ch.SupportedModels()
+ names := make([]string, 0, len(got))
+ for _, m := range got {
+ names = append(names, m.Name)
+ }
+ require.ElementsMatch(t, []string{"claude-sonnet-4-5", "claude-sonnet-4-6", "claude-opus-4-6"}, names)
+ for _, m := range got {
+ require.NotContains(t, m.Name, "*")
+ }
+}
+
+
+func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) {
+ ch := &Channel{
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"},
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+ require.Nil(t, got[0].Pricing)
+}
+
+func TestSupportedModels_DedupAndSort(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}},
+ {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "claude-sonnet-4-6": "upstream-a",
+ "claude-sonnet-*": "upstream-a",
+ },
+ "openai": {"gpt-4o": "gpt-4o"},
+ },
+ }
+
+ got := ch.SupportedModels()
+ require.Len(t, got, 3)
+ require.Equal(t, "anthropic", got[0].Platform)
+ require.Equal(t, "claude-sonnet-4-5", got[0].Name)
+ require.Equal(t, "anthropic", got[1].Platform)
+ require.Equal(t, "claude-sonnet-4-6", got[1].Name)
+ require.Equal(t, "openai", got[2].Platform)
+ require.Equal(t, "gpt-4o", got[2].Name)
+}
+
+func TestSupportedModels_NilChannelAndEmpty(t *testing.T) {
+ var nilCh *Channel
+ require.Nil(t, nilCh.SupportedModels())
+
+ empty := &Channel{}
+ require.Nil(t, empty.SupportedModels())
+}
+
+func TestGetModelPricingByPlatform(t *testing.T) {
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)},
+ {ID: 2, Platform: "openai", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(1e-6)},
+ },
+ }
+
+ ant := ch.GetModelPricingByPlatform("anthropic", "claude-sonnet-4-6")
+ require.NotNil(t, ant)
+ require.Equal(t, int64(1), ant.ID)
+
+ oa := ch.GetModelPricingByPlatform("openai", "claude-sonnet-4-6")
+ require.NotNil(t, oa)
+ require.Equal(t, int64(2), oa.ID)
+
+ require.Nil(t, ch.GetModelPricingByPlatform("gemini", "claude-sonnet-4-6"))
+}
+
+func TestSupportedModels_WildcardOnlyPricingRowsSkipped(t *testing.T) {
+ // 定价中含通配符条目(pattern),不应被当作具体模型名展开。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-*", "claude-sonnet-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-*": "claude-sonnet-4-6"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+ for _, m := range got {
+ require.NotContains(t, m.Name, "*")
+ }
+}
+
+func TestSupportedModels_WildcardPrefixMatchesNothing(t *testing.T) {
+ // 通配符模式无任何对应定价模型时,该平台 mapping 路不产出;
+ // 但其他平台的 pricing-only 模型仍会通过 Pass B 出现。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"gpt-foo-*": "gpt-foo-1"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "openai", got[0].Platform)
+ require.Equal(t, "gpt-4o", got[0].Name)
+}
+
+func TestSupportedModels_CrossPlatformPricingDoesNotBleed(t *testing.T) {
+ // anthropic 的通配符不应把 openai 定价行拉到 anthropic 平台下;
+ // openai 的 pricing-only 模型则正常通过 Pass B 暴露在 openai 平台下。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"claude-sonnet-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {"claude-sonnet-*": "x"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "openai", got[0].Platform, "不能把 openai 定价标记为 anthropic 模型")
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+}
+
+func TestSupportedModels_CaseInsensitiveDedup(t *testing.T) {
+ // 两行定价用不同大小写定义了同一模型,结果应去重为 1 条;首次出现的原始大小写保留。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"GPT-4o"}},
+ {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "openai": {"gpt-*": "x"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "GPT-4o", got[0].Name)
+}
+
+func TestSupportedModels_EmptyPlatformMapping(t *testing.T) {
+ // ModelMapping 平台 key 存在但 value 为空 map:mapping 路跳过该平台,
+ // 但 pricing 路仍会把该平台的定价模型补齐(关键修复:azcc 这种"只配定价不配映射"渠道)。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "anthropic", got[0].Platform)
+ require.Equal(t, "claude-sonnet-4-6", got[0].Name)
+ require.NotNil(t, got[0].Pricing)
+}
+
+func TestSupportedModels_ExactKeyUsesPricedCaseWhenAvailable(t *testing.T) {
+ // mapping key uses uppercase, pricing uses lowercase — pricing's case should win.
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "openai": {"GPT-4o": "gpt-4o"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 1)
+ require.Equal(t, "gpt-4o", got[0].Name) // pricing's case wins
+}
+
+func TestSupportedModels_AsteriskOnlyMappingExpandsAllPriced(t *testing.T) {
+ // 映射 key 为单独的 "*":前缀为空 → 命中该平台所有定价模型(透传场景)。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "openai", Models: []string{"gpt-4o", "gpt-4o-mini"}},
+ },
+ ModelMapping: map[string]map[string]string{
+ "openai": {"*": "gpt-4o"},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ names := []string{got[0].Name, got[1].Name}
+ require.ElementsMatch(t, []string{"gpt-4o", "gpt-4o-mini"}, names)
+}
+
+func TestSupportedModels_PricingOnlyNoMapping(t *testing.T) {
+ // 渠道完全没配 mapping,只配了定价 —— 应该把所有定价模型作为支持模型返回。
+ // 这是修复前的核心 bug 场景(前端显示"未配置模型")。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)},
+ {ID: 2, Platform: "anthropic", Models: []string{"claude-haiku-4-5"}, InputPrice: testPtrFloat64(3e-7)},
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ require.Equal(t, "claude-haiku-4-5", got[0].Name)
+ require.NotNil(t, got[0].Pricing)
+ require.Equal(t, int64(2), got[0].Pricing.ID)
+ require.Equal(t, "claude-opus-4-6", got[1].Name)
+ require.Equal(t, int64(1), got[1].Pricing.ID)
+}
+
+func TestSupportedModels_ExactMappingUsesTargetPricing(t *testing.T) {
+ // 精确 mapping `src → target`:定价应按 target 查(实际计费的是 target),
+ // 而不是按 src 自查。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 100, Platform: "anthropic", Models: []string{"req-model"}, InputPrice: testPtrFloat64(3e-6)},
+ {ID: 200, Platform: "anthropic", Models: []string{"served-model"}, InputPrice: testPtrFloat64(1.5e-5)},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "req-model": "served-model",
+ },
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ require.Equal(t, "req-model", got[0].Name)
+ require.NotNil(t, got[0].Pricing)
+ require.Equal(t, int64(200), got[0].Pricing.ID, "req-model 显示但定价是 served-model 的(mapping target)")
+ require.Equal(t, "served-model", got[1].Name)
+ require.Equal(t, int64(200), got[1].Pricing.ID)
+}
+
+func TestSupportedModels_ExactMappingTargetMissingFromPricing(t *testing.T) {
+ // `src → target` 但 target 不在渠道定价里 —— 结果中 src 的 Pricing 为 nil
+ // (等待 ListAvailable 阶段的全局 LiteLLM 回落填充)。
+ ch := &Channel{
+ ModelPricing: []ChannelModelPricing{
+ {ID: 1, Platform: "anthropic", Models: []string{"some-priced-model"}, InputPrice: testPtrFloat64(1.5e-5)},
+ },
+ ModelMapping: map[string]map[string]string{
+ "anthropic": {
+ "missing-src": "missing-target",
+ },
+ },
+ }
+ got := ch.SupportedModels()
+ require.Len(t, got, 2)
+ require.Equal(t, "missing-src", got[0].Name)
+ require.Nil(t, got[0].Pricing, "target 在渠道定价中缺失时不虚假填充,留给 ListAvailable 走 LiteLLM 回落")
+ require.Equal(t, "some-priced-model", got[1].Name)
+ require.NotNil(t, got[1].Pricing)
+}
diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go
index 82fa31c4..d70379c1 100644
--- a/backend/internal/service/claude_token_provider.go
+++ b/backend/internal/service/claude_token_provider.go
@@ -17,7 +17,7 @@ const (
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
-// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
+// ClaudeTokenProvider manages access_token for Claude OAuth and Vertex service account accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
@@ -56,8 +56,11 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
- if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
- return "", errors.New("not an anthropic oauth account")
+ if account.Platform != PlatformAnthropic || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
+ return "", errors.New("not an anthropic oauth or service account")
+ }
+ if account.Type == AccountTypeServiceAccount {
+ return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := ClaudeTokenCacheKey(account)
@@ -157,3 +160,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
+
+func (p *ClaudeTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
+ return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
+}
diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go
index 3e21f6f4..d4a4a14a 100644
--- a/backend/internal/service/claude_token_provider_test.go
+++ b/backend/internal/service/claude_token_provider_test.go
@@ -137,7 +137,7 @@ func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *A
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
- return "", errors.New("not an anthropic oauth account")
+ return "", errors.New("not an anthropic oauth or service account")
}
cacheKey := ClaudeTokenCacheKey(account)
@@ -371,7 +371,7 @@ func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
@@ -385,7 +385,7 @@ func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
@@ -399,7 +399,7 @@ func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
- require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Contains(t, err.Error(), "not an anthropic oauth or service account")
require.Empty(t, token)
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index cb452efb..bb32540b 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -18,6 +18,19 @@ const (
RoleUser = domain.RoleUser
)
+// Affiliate rebate settings
+const (
+ AffiliateRebateRateDefault = 20.0
+ AffiliateRebateRateMin = 0.0
+ AffiliateRebateRateMax = 100.0
+ AffiliateEnabledDefault = false // 邀请返利总开关默认关闭
+ AffiliateRebateFreezeHoursDefault = 0 // 0 = 不冻结(向后兼容)
+ AffiliateRebateFreezeHoursMax = 720 // 最大 30 天
+ AffiliateRebateDurationDaysDefault = 0 // 0 = 永久有效
+ AffiliateRebateDurationDaysMax = 3650 // ~10 年
+ AffiliateRebatePerInviteeCapDefault = 0.0 // 0 = 无上限
+)
+
// Platform constants
const (
PlatformAnthropic = domain.PlatformAnthropic
@@ -28,11 +41,12 @@ const (
// Account type constants
const (
- AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
- AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
- AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
- AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
- AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
+ AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
+ AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
+ AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
+ AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
+ AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名或 API Key 连接 Bedrock,由 credentials.auth_mode 区分)
+ AccountTypeServiceAccount = domain.AccountTypeServiceAccount // Google Service Account 类型账号(用于 Vertex AI)
)
// Redeem type constants
@@ -74,6 +88,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
+// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
+const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
+
// Setting keys
const (
// 注册设置
@@ -84,6 +101,11 @@ const (
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
SettingKeyFrontendURL = "frontend_url" // 前端基础URL,用于生成邮件中的重置密码链接
SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册
+ SettingKeyAffiliateEnabled = "affiliate_enabled" // 邀请返利功能总开关
+ SettingKeyAffiliateRebateRate = "affiliate_rebate_rate" // 邀请返利比例(百分比,0-100)
+ SettingKeyAffiliateRebateFreezeHours = "affiliate_rebate_freeze_hours" // 返利冻结期(小时,0=不冻结)
+ SettingKeyAffiliateRebateDurationDays = "affiliate_rebate_duration_days" // 返利有效期(天,0=永久)
+ SettingKeyAffiliateRebatePerInviteeCap = "affiliate_rebate_per_invitee_cap" // 单人返利上限(0=无上限)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
@@ -108,6 +130,24 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
+ // WeChat Connect OAuth 登录设置
+ SettingKeyWeChatConnectEnabled = "wechat_connect_enabled"
+ SettingKeyWeChatConnectAppID = "wechat_connect_app_id"
+ SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret"
+ SettingKeyWeChatConnectOpenAppID = "wechat_connect_open_app_id"
+ SettingKeyWeChatConnectOpenAppSecret = "wechat_connect_open_app_secret"
+ SettingKeyWeChatConnectMPAppID = "wechat_connect_mp_app_id"
+ SettingKeyWeChatConnectMPAppSecret = "wechat_connect_mp_app_secret"
+ SettingKeyWeChatConnectMobileAppID = "wechat_connect_mobile_app_id"
+ SettingKeyWeChatConnectMobileAppSecret = "wechat_connect_mobile_app_secret"
+ SettingKeyWeChatConnectOpenEnabled = "wechat_connect_open_enabled"
+ SettingKeyWeChatConnectMPEnabled = "wechat_connect_mp_enabled"
+ SettingKeyWeChatConnectMobileEnabled = "wechat_connect_mobile_enabled"
+ SettingKeyWeChatConnectMode = "wechat_connect_mode"
+ SettingKeyWeChatConnectScopes = "wechat_connect_scopes"
+ SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url"
+ SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url"
+
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
@@ -149,9 +189,33 @@ const (
SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组)
// 默认配置
- SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
- SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
- SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
+ SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
+ SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
+ SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
+ SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
+
+ // 第三方认证来源默认授予配置
+ SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
+ SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
+ SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
+ SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
+ SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
+ SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
+ SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
+ SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
@@ -198,6 +262,23 @@ const (
// SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings.
SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config"
+ // =========================
+ // Channel Monitor (渠道监控)
+ // =========================
+
+ // SettingKeyChannelMonitorEnabled is a DB-backed soft switch for the channel monitor feature.
+ // When false: runner skips scheduling and user-facing endpoints return an empty list.
+ SettingKeyChannelMonitorEnabled = "channel_monitor_enabled"
+
+ // SettingKeyChannelMonitorDefaultIntervalSeconds controls the default interval (seconds)
+ // pre-filled when creating a new channel monitor from the admin UI. Range: [15, 3600].
+ SettingKeyChannelMonitorDefaultIntervalSeconds = "channel_monitor_default_interval_seconds"
+
+ // SettingKeyAvailableChannelsEnabled is a DB-backed soft switch for the "Available Channels"
+ // user-facing aggregate view. When false: user endpoint returns an empty list and the
+ // sidebar entry is hidden. Defaults to false (opt-in feature).
+ SettingKeyAvailableChannelsEnabled = "available_channels_enabled"
+
// =========================
// Overload Cooldown (529)
// =========================
@@ -226,6 +307,12 @@ const (
// SettingKeyBetaPolicySettings stores JSON config for beta policy rules.
SettingKeyBetaPolicySettings = "beta_policy_settings"
+ // SettingKeyOpenAIFastPolicySettings stores JSON config for OpenAI
+ // service_tier (fast/flex) policy rules. Mirrors BetaPolicySettings but
+ // targets OpenAI's body-level service_tier field instead of Claude's
+ // anthropic-beta header.
+ SettingKeyOpenAIFastPolicySettings = "openai_fast_policy_settings"
+
// =========================
// Claude Code Version Check
// =========================
@@ -249,6 +336,8 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false)
SettingKeyEnableCCHSigning = "enable_cch_signing"
+ // SettingKeyEnableAnthropicCacheTTL1hInjection 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
+ SettingKeyEnableAnthropicCacheTTL1hInjection = "enable_anthropic_cache_ttl_1h_injection"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
index 5be1f733..428231ee 100644
--- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
+++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
@@ -762,8 +762,14 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.True(t, system.IsArray(), "system should be an array")
- require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String())
- require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String())
+ arr := system.Array()
+ require.Len(t, arr, 2, "system array should have billing block + cc prompt block")
+
+ require.Contains(t, arr[0].Get("text").String(), "x-anthropic-billing-header:")
+ require.Contains(t, arr[0].Get("text").String(), "cc_version=")
+
+ require.Equal(t, claudeCodeSystemPrompt, arr[1].Get("text").String())
+ require.Equal(t, "ephemeral", arr[1].Get("cache_control.type").String())
// 原始 system prompt 应迁移至 messages 中
messages := gjson.GetBytes(upstream.lastBody, "messages")
diff --git a/backend/internal/service/gateway_anthropic_vertex_service_account_test.go b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go
new file mode 100644
index 00000000..aa779805
--- /dev/null
+++ b/backend/internal/service/gateway_anthropic_vertex_service_account_test.go
@@ -0,0 +1,68 @@
+package service
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestGatewayService_BuildAnthropicVertexServiceAccountRequest(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+ c.Request.Header.Set("Authorization", "Bearer inbound-token")
+ c.Request.Header.Set("X-Api-Key", "inbound-api-key")
+ c.Request.Header.Set("Anthropic-Version", "2023-06-01")
+ c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14")
+
+ account := &Account{
+ ID: 301,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeServiceAccount,
+ Credentials: map[string]any{
+ "project_id": "vertex-proj",
+ "location": "us-east5",
+ },
+ }
+ body := []byte(`{"model":"claude-sonnet-4-5","stream":false,"max_tokens":32,"messages":[{"role":"user","content":"hello"}]}`)
+
+ svc := &GatewayService{}
+ req, err := svc.buildUpstreamRequest(
+ context.Background(),
+ c,
+ account,
+ body,
+ "vertex-token",
+ "service_account",
+ "claude-sonnet-4-5@20250929",
+ false,
+ false,
+ )
+ require.NoError(t, err)
+ require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/vertex-proj/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", req.URL.String())
+ require.Equal(t, "Bearer vertex-token", getHeaderRaw(req.Header, "authorization"))
+ require.Empty(t, getHeaderRaw(req.Header, "x-api-key"))
+ require.Empty(t, getHeaderRaw(req.Header, "anthropic-version"))
+ require.Equal(t, "interleaved-thinking-2025-05-14", getHeaderRaw(req.Header, "anthropic-beta"))
+
+ got := readRequestBodyForTest(t, req)
+ require.Equal(t, "", gjson.GetBytes(got, "model").String())
+ require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
+ require.Equal(t, "hello", gjson.GetBytes(got, "messages.0.content").String())
+}
+
+func readRequestBodyForTest(t *testing.T, req *http.Request) []byte {
+ t.Helper()
+ require.NotNil(t, req.Body)
+ body, err := io.ReadAll(req.Body)
+ require.NoError(t, err)
+ return body
+}
diff --git a/backend/internal/service/gateway_billing_block.go b/backend/internal/service/gateway_billing_block.go
new file mode 100644
index 00000000..45c307fd
--- /dev/null
+++ b/backend/internal/service/gateway_billing_block.go
@@ -0,0 +1,98 @@
+package service
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+)
+
+// fingerprintSalt 是计算 cc_version 后缀指纹的盐值。
+//
+// 来源:与 Parrot src/transform/cc_mimicry.py 的 FINGERPRINT_SALT 完全一致;
+// 这是真实 Claude Code CLI 抓包推导出的常量,改动会导致 fp 与 CLI 不一致,
+// 进一步触发 Anthropic 的第三方检测。
+const fingerprintSalt = "59cf53e54c78"
+
+// computeClaudeCodeFingerprint 复刻真实 Claude Code CLI 的 cc_version 指纹算法:
+//
+// 1. 取 messages 中第一条 role=user 的纯文本(首块 text)
+// 2. 取该文本的第 4、7、20 字符(不足以 '0' 补齐)
+// 3. SHA256(SALT + chars + cc_version) 取 hex 前 3 字符
+//
+// 算法来自 Parrot src/transform/cc_mimicry.py:compute_fingerprint,与官方 CLI 字节对齐。
+// 任何偏差都会导致 cc_version=X.Y.Z.{fp} 在上游侧与真实 CLI 不一致。
+func computeClaudeCodeFingerprint(body []byte, version string) string {
+ firstText := extractFirstUserText(body)
+ indices := []int{4, 7, 20}
+ chars := make([]byte, 0, 3)
+ for _, i := range indices {
+ if i < len(firstText) {
+ chars = append(chars, firstText[i])
+ } else {
+ chars = append(chars, '0')
+ }
+ }
+ sum := sha256.Sum256([]byte(fingerprintSalt + string(chars) + version))
+ return hex.EncodeToString(sum[:])[:3]
+}
+
+// extractFirstUserText 提取 messages 中第一条 user 消息的首段 text 内容。
+// 兼容 string 和 []block 两种 content 格式。
+func extractFirstUserText(body []byte) string {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() {
+ return ""
+ }
+ first := ""
+ messages.ForEach(func(_, msg gjson.Result) bool {
+ if msg.Get("role").String() != "user" {
+ return true
+ }
+ content := msg.Get("content")
+ if content.Type == gjson.String {
+ first = content.String()
+ return false
+ }
+ if content.IsArray() {
+ content.ForEach(func(_, block gjson.Result) bool {
+ if block.Get("type").String() == "text" {
+ first = block.Get("text").String()
+ return false
+ }
+ return true
+ })
+ return false
+ }
+ return false
+ })
+ return first
+}
+
+// buildBillingAttributionBlockJSON 构造 system 数组的 billing attribution block。
+//
+// 形态严格对齐真实 Claude Code CLI:
+//
+// {"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.92.{fp}; cc_entrypoint=cli; cch=00000;"}
+//
+// cch=00000 是签名占位符,由 signBillingHeaderCCH 在 buildUpstreamRequest 阶段
+// 替换为基于完整 body 的 xxhash64 5 位十六进制摘要。
+//
+// 此 block 不带 cache_control(与真实 CLI 一致;cache breakpoint 由后续的
+// Claude Code prompt block 承担)。
+func buildBillingAttributionBlockJSON(body []byte, cliVersion string) ([]byte, error) {
+ if cliVersion == "" {
+ return nil, fmt.Errorf("cliVersion required")
+ }
+ fp := computeClaudeCodeFingerprint(body, cliVersion)
+ text := fmt.Sprintf(
+ "x-anthropic-billing-header: cc_version=%s.%s; cc_entrypoint=cli; cch=00000;",
+ cliVersion, fp,
+ )
+ return json.Marshal(map[string]string{
+ "type": "text",
+ "text": text,
+ })
+}
diff --git a/backend/internal/service/gateway_body_order_test.go b/backend/internal/service/gateway_body_order_test.go
index 641522f0..e0c3cafd 100644
--- a/backend/internal/service/gateway_body_order_test.go
+++ b/backend/internal/service/gateway_body_order_test.go
@@ -1,13 +1,91 @@
package service
import (
+ "context"
+ "errors"
"strings"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
+type gatewayTTLSettingRepo struct {
+ data map[string]string
+}
+
+func (r *gatewayTTLSettingRepo) Get(context.Context, string) (*Setting, error) {
+ return nil, ErrSettingNotFound
+}
+
+func (r *gatewayTTLSettingRepo) GetValue(_ context.Context, key string) (string, error) {
+ if r == nil {
+ return "", ErrSettingNotFound
+ }
+ v, ok := r.data[key]
+ if !ok {
+ return "", ErrSettingNotFound
+ }
+ return v, nil
+}
+
+func (r *gatewayTTLSettingRepo) Set(_ context.Context, key, value string) error {
+ if r == nil {
+ return errors.New("setting repo is nil")
+ }
+ if r.data == nil {
+ r.data = map[string]string{}
+ }
+ r.data[key] = value
+ return nil
+}
+
+func (r *gatewayTTLSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string)
+ if r == nil {
+ return result, nil
+ }
+ for _, key := range keys {
+ if v, ok := r.data[key]; ok {
+ result[key] = v
+ }
+ }
+ return result, nil
+}
+
+func (r *gatewayTTLSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
+ if r == nil {
+ return errors.New("setting repo is nil")
+ }
+ if r.data == nil {
+ r.data = map[string]string{}
+ }
+ for key, value := range settings {
+ r.data[key] = value
+ }
+ return nil
+}
+
+func (r *gatewayTTLSettingRepo) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string)
+ if r == nil {
+ return result, nil
+ }
+ for key, value := range r.data {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (r *gatewayTTLSettingRepo) Delete(_ context.Context, key string) error {
+ if r != nil {
+ delete(r.data, key)
+ }
+ return nil
+}
+
func assertJSONTokenOrder(t *testing.T, body string, tokens ...string) {
t.Helper()
@@ -41,12 +119,13 @@ func TestNormalizeClaudeOAuthRequestBody_PreservesTopLevelFieldOrder(t *testing.
resultStr := string(result)
require.Equal(t, claude.NormalizeModelID("claude-3-5-sonnet-latest"), modelID)
- assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`)
- require.NotContains(t, resultStr, `"temperature"`)
+ assertJSONTokenOrder(t, resultStr, `"alpha"`, `"model"`, `"temperature"`, `"system"`, `"messages"`, `"omega"`, `"tools"`, `"metadata"`, `"max_tokens"`)
+ require.Contains(t, resultStr, `"temperature":0.2`)
require.NotContains(t, resultStr, `"tool_choice"`)
require.Contains(t, resultStr, `"system":"`+claudeCodeSystemPrompt+`"`)
require.Contains(t, resultStr, `"tools":[]`)
require.Contains(t, resultStr, `"metadata":{"user_id":"user-1"}`)
+ require.Contains(t, resultStr, `"max_tokens":128000`)
}
func TestInjectClaudeCodePrompt_PreservesFieldOrder(t *testing.T) {
@@ -70,3 +149,60 @@ func TestEnforceCacheControlLimit_PreservesTopLevelFieldOrder(t *testing.T) {
assertJSONTokenOrder(t, resultStr, `"alpha"`, `"system"`, `"messages"`, `"omega"`)
require.Equal(t, 4, strings.Count(resultStr, `"cache_control"`))
}
+
+func TestInjectAnthropicCacheControlTTL1h_OnlyUpdatesExistingEphemeralCacheControl(t *testing.T) {
+ body := []byte(`{"alpha":1,"cache_control":{"type":"ephemeral"},"system":[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}},{"type":"text","text":"plain"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}},{"type":"text","text":"non","cache_control":{"type":"persistent","ttl":"5m"}}]}],"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral"}}],"omega":2}`)
+
+ result := injectAnthropicCacheControlTTL1h(body)
+ resultStr := string(result)
+
+ assertJSONTokenOrder(t, resultStr, `"alpha"`, `"cache_control"`, `"system"`, `"messages"`, `"tools"`, `"omega"`)
+ require.Equal(t, "1h", gjson.GetBytes(result, "cache_control.ttl").String())
+ require.Equal(t, "1h", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
+ require.False(t, gjson.GetBytes(result, "system.1.cache_control").Exists())
+ require.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
+ require.Equal(t, "5m", gjson.GetBytes(result, "messages.0.content.1.cache_control.ttl").String())
+ require.Equal(t, "1h", gjson.GetBytes(result, "tools.0.cache_control.ttl").String())
+}
+
+func TestGatewayCacheTTLGlobalSetting_TargetResolution(t *testing.T) {
+ repo := &gatewayTTLSettingRepo{data: map[string]string{
+ SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
+ }}
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
+ svc := &GatewayService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+ account := &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}
+
+ target, ok := svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
+ require.True(t, ok)
+ require.Equal(t, cacheTTLTarget5m, target)
+
+ account.Extra = map[string]any{
+ "cache_ttl_override_enabled": true,
+ "cache_ttl_override_target": "1h",
+ }
+ target, ok = svc.resolveCacheTTLUsageOverrideTarget(context.Background(), account)
+ require.True(t, ok)
+ require.Equal(t, cacheTTLTarget1h, target)
+}
+
+func TestGatewayCacheTTLGlobalSetting_RequestInjectionScope(t *testing.T) {
+ repo := &gatewayTTLSettingRepo{data: map[string]string{
+ SettingKeyEnableAnthropicCacheTTL1hInjection: "true",
+ }}
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
+ svc := &GatewayService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+
+ require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
+ require.True(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeSetupToken}))
+ require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeAPIKey}))
+ require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}))
+
+ repo.data[SettingKeyEnableAnthropicCacheTTL1hInjection] = "false"
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{})
+ require.False(t, svc.shouldInjectAnthropicCacheTTL1h(context.Background(), &Account{Platform: PlatformAnthropic, Type: AccountTypeOAuth}))
+}
diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go
index 37b38f76..7ac77f77 100644
--- a/backend/internal/service/gateway_forward_as_chat_completions.go
+++ b/backend/internal/service/gateway_forward_as_chat_completions.go
@@ -61,10 +61,15 @@ func (s *GatewayService) ForwardAsChatCompletions(
// 4. Model mapping
mappedModel := originalModel
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
- if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
+ if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
+ if normalized != originalModel {
+ mappedModel = normalized
+ }
+ } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
@@ -85,15 +90,16 @@ func (s *GatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
- // 6. Apply Claude Code mimicry for OAuth accounts
- isClaudeCode := false // CC API is never Claude Code
+ // 6. Apply Claude Code mimicry for OAuth accounts.
+ // Chat Completions 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
+ // 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入),
+ // 否则会被 Anthropic 判为第三方应用并扣 extra usage。
+ // 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
+ isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
- if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
- !systemIncludesClaudeCodePrompt(anthropicReq.System) {
- anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
- }
+ anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
@@ -312,7 +318,14 @@ func (s *GatewayService) handleCCBufferedFromAnthropic(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
- c.JSON(http.StatusOK, ccResp)
+ // Marshal then bytes-replace so tool name mapping is reversed at byte level
+ // (parity with Parrot non-stream flow that marshals → restore → emit).
+ if respBytes, err := json.Marshal(ccResp); err == nil {
+ respBytes = reverseToolNamesIfPresent(c, respBytes)
+ c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
+ } else {
+ c.JSON(http.StatusOK, ccResp)
+ }
return &ForwardResult{
RequestID: requestID,
@@ -383,7 +396,10 @@ func (s *GatewayService) handleCCStreamingFromAnthropic(
if err != nil {
return false
}
- if _, err := fmt.Fprint(c.Writer, sse); err != nil {
+ // Reverse tool name mapping: fake → real, per-chunk bytes.Replace.
+ // c 可能持有请求侧注入的 ToolNameRewrite;无则仅做静态前缀还原。
+ out := string(reverseToolNamesIfPresent(c, []byte(sse)))
+ if _, err := fmt.Fprint(c.Writer, out); err != nil {
return true // client disconnected
}
return false
diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go
index 2c917112..8f8a1e94 100644
--- a/backend/internal/service/gateway_forward_as_responses.go
+++ b/backend/internal/service/gateway_forward_as_responses.go
@@ -58,10 +58,15 @@ func (s *GatewayService) ForwardAsResponses(
// 4. Model mapping
mappedModel := originalModel
reasoningEffort := ExtractResponsesReasoningEffortFromBody(body)
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
- if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
+ if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(originalModel))
+ if normalized != originalModel {
+ mappedModel = normalized
+ }
+ } else if mappedModel == originalModel && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(originalModel)
if normalized != originalModel {
mappedModel = normalized
@@ -82,15 +87,16 @@ func (s *GatewayService) ForwardAsResponses(
return nil, fmt.Errorf("marshal anthropic request: %w", err)
}
- // 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints)
- isClaudeCode := false // Responses API is never Claude Code
+ // 6. Apply Claude Code mimicry for OAuth accounts (non-Claude-Code endpoints).
+ // OpenAI Responses 协议进来的请求永远不是 Claude Code 客户端,所以对 OAuth 账号
+ // 必须完整执行 /v1/messages 主路径上的伪装链路(system 重写 + normalize + metadata 注入),
+ // 否则会被 Anthropic 判为第三方应用并扣 extra usage。
+ // 见 applyClaudeCodeOAuthMimicryToBody 的 godoc。
+ isClaudeCode := false
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
- if !strings.Contains(strings.ToLower(mappedModel), "haiku") &&
- !systemIncludesClaudeCodePrompt(anthropicReq.System) {
- anthropicBody = injectClaudeCodePrompt(anthropicBody, anthropicReq.System)
- }
+ anthropicBody = s.applyClaudeCodeOAuthMimicryToBody(ctx, c, account, anthropicBody, anthropicReq.System, mappedModel)
}
// 7. Enforce cache_control block limit
@@ -331,7 +337,12 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse(
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
- c.JSON(http.StatusOK, responsesResp)
+ if respBytes, err := json.Marshal(responsesResp); err == nil {
+ respBytes = reverseToolNamesIfPresent(c, respBytes)
+ c.Data(http.StatusOK, "application/json; charset=utf-8", respBytes)
+ } else {
+ c.JSON(http.StatusOK, responsesResp)
+ }
return &ForwardResult{
RequestID: requestID,
@@ -419,7 +430,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
)
continue
}
- if _, err := fmt.Fprint(c.Writer, sse); err != nil {
+ out := string(reverseToolNamesIfPresent(c, []byte(sse)))
+ if _, err := fmt.Fprint(c.Writer, out); err != nil {
logger.L().Info("forward_as_responses stream: client disconnected",
zap.String("request_id", requestID),
)
@@ -439,7 +451,8 @@ func (s *GatewayService) handleResponsesStreamingResponse(
if err != nil {
continue
}
- fmt.Fprint(c.Writer, sse) //nolint:errcheck
+ out := string(reverseToolNamesIfPresent(c, []byte(sse)))
+ fmt.Fprint(c.Writer, out) //nolint:errcheck
}
c.Writer.Flush()
}
diff --git a/backend/internal/service/gateway_messages_cache.go b/backend/internal/service/gateway_messages_cache.go
new file mode 100644
index 00000000..cb5384ba
--- /dev/null
+++ b/backend/internal/service/gateway_messages_cache.go
@@ -0,0 +1,141 @@
+package service
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// stripMessageCacheControl 移除 $.messages[*].content[*].cache_control。
+// 与 Parrot _strip_message_cache_control 语义一致。
+//
+// 为什么必须整体清空:客户端(特别是 Claude Code)经常把 cache_control 打在
+// "当前最后一条 user message" 上;下一轮对话 messages 追加后,原本的最后一条
+// 变成中间某条,cache_control 还挂着就导致"前缀签名变化",破坏缓存命中。
+// 统一由代理重新打断点(addMessageCacheBreakpoints)才能在多轮间稳定。
+func stripMessageCacheControl(body []byte) []byte {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() {
+ return body
+ }
+ msgIdx := -1
+ messages.ForEach(func(_, msg gjson.Result) bool {
+ msgIdx++
+ content := msg.Get("content")
+ if !content.IsArray() {
+ return true
+ }
+ blockIdx := -1
+ content.ForEach(func(_, block gjson.Result) bool {
+ blockIdx++
+ if !block.Get("cache_control").Exists() {
+ return true
+ }
+ path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, blockIdx)
+ if next, err := sjson.DeleteBytes(body, path); err == nil {
+ body = next
+ }
+ return true
+ })
+ return true
+ })
+ return body
+}
+
+// addMessageCacheBreakpoints 在 messages 上注入两个稳定的 cache 断点:
+// 1. 最后一条 message
+// 2. 当 messages 数量 ≥ 4 时,倒数第二个 role=user 的 message
+//
+// 与 Parrot add_cache_breakpoints 一致。两个断点 + system prompt block 的断点
+// + tools[-1] 的断点共同构成最多 4 个断点(Anthropic 上限)。
+//
+// cache_control ttl 策略:
+// - 若目标 block 已有 cache_control.ttl → 不覆盖
+// - 否则写入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
+//
+// 调用前应先 stripMessageCacheControl 以保证幂等和稳定。
+func addMessageCacheBreakpoints(body []byte) []byte {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() {
+ return body
+ }
+ arr := messages.Array()
+ if len(arr) == 0 {
+ return body
+ }
+
+ body = injectCacheControlOnLastContentBlock(body, len(arr)-1, &arr[len(arr)-1])
+
+ if len(arr) >= 4 {
+ userCount := 0
+ for i := len(arr) - 1; i >= 0; i-- {
+ if arr[i].Get("role").String() != "user" {
+ continue
+ }
+ userCount++
+ if userCount == 2 {
+ body = injectCacheControlOnLastContentBlock(body, i, &arr[i])
+ break
+ }
+ }
+ }
+
+ return body
+}
+
+// injectCacheControlOnLastContentBlock 把 cache_control 断点打在 messages[idx]
+// 的最后一个 content block 上。若 content 是 string,先升级成单块 text 数组
+// (对齐 Parrot _inject_cache_on_msg 的行为)。
+//
+// msg 是调用方已持有的 gjson.Result 快照,用于省一次 GetBytes。
+func injectCacheControlOnLastContentBlock(body []byte, idx int, msg *gjson.Result) []byte {
+ content := msg.Get("content")
+
+ if content.Type == gjson.String {
+ text := content.String()
+ blockRaw := fmt.Sprintf(
+ `[{"type":"text","text":%s,"cache_control":{"type":"ephemeral","ttl":%q}}]`,
+ mustJSONString(text), claude.DefaultCacheControlTTL,
+ )
+ if next, err := sjson.SetRawBytes(body, fmt.Sprintf("messages.%d.content", idx), []byte(blockRaw)); err == nil {
+ body = next
+ }
+ return body
+ }
+
+ if !content.IsArray() {
+ return body
+ }
+ contentArr := content.Array()
+ if len(contentArr) == 0 {
+ return body
+ }
+ lastBlockIdx := len(contentArr) - 1
+ lastBlock := contentArr[lastBlockIdx]
+
+ if cc := lastBlock.Get("cache_control"); cc.Exists() && cc.Get("ttl").String() != "" {
+ return body
+ }
+
+ pathPrefix := fmt.Sprintf("messages.%d.content.%d.cache_control", idx, lastBlockIdx)
+ existingCC := lastBlock.Get("cache_control")
+ if existingCC.Exists() {
+ if next, err := sjson.SetBytes(body, pathPrefix+".ttl", claude.DefaultCacheControlTTL); err == nil {
+ body = next
+ }
+ return body
+ }
+ raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
+ if next, err := sjson.SetRawBytes(body, pathPrefix, []byte(raw)); err == nil {
+ body = next
+ }
+ return body
+}
+
+// mustJSONString 把一个 Go string 序列化为合法 JSON string(含引号),
+// 用于 sjson.SetRawBytes 场景下手工拼 JSON。
+func mustJSONString(s string) string {
+ return fmt.Sprintf("%q", s)
+}
diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go
index e27e18aa..f3a22c1d 100644
--- a/backend/internal/service/gateway_prompt_test.go
+++ b/backend/internal/service/gateway_prompt_test.go
@@ -9,6 +9,11 @@ import (
)
func TestIsClaudeCodeClient(t *testing.T) {
+ // 合法的 legacy 格式 metadata.user_id(64位 hex + account uuid + session uuid)
+ legacyUserID := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000"
+ // 合法的 JSON 格式 metadata.user_id(2.1.78+ 版本)
+ jsonUserID := `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"123e4567-e89b-12d3-a456-426614174000"}`
+
tests := []struct {
name string
userAgent string
@@ -16,15 +21,21 @@ func TestIsClaudeCodeClient(t *testing.T) {
want bool
}{
{
- name: "Claude Code client",
+ name: "Claude Code client with legacy user_id",
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
- metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
+ metadataUserID: legacyUserID,
want: true,
},
{
- name: "Claude Code without version suffix",
- userAgent: "claude-cli/2.0.0",
- metadataUserID: "session_abc",
+ name: "Claude Code client with JSON user_id",
+ userAgent: "claude-cli/2.1.92 (external, cli)",
+ metadataUserID: jsonUserID,
+ want: true,
+ },
+ {
+ name: "Claude Code case insensitive UA",
+ userAgent: "Claude-CLI/2.0.0",
+ metadataUserID: legacyUserID,
want: true,
},
{
@@ -34,21 +45,33 @@ func TestIsClaudeCodeClient(t *testing.T) {
want: false,
},
{
- name: "Different user agent",
+ name: "Claude CLI UA with invalid user_id format",
+ userAgent: "claude-cli/2.0.0",
+ metadataUserID: "fake-user-id-12345",
+ want: false,
+ },
+ {
+ name: "Different user agent with valid user_id",
userAgent: "curl/7.68.0",
- metadataUserID: "user123",
+ metadataUserID: legacyUserID,
want: false,
},
{
name: "Empty user agent",
userAgent: "",
- metadataUserID: "user123",
+ metadataUserID: legacyUserID,
want: false,
},
{
name: "Similar but not Claude CLI",
userAgent: "claude-api/1.0.0",
- metadataUserID: "user123",
+ metadataUserID: legacyUserID,
+ want: false,
+ },
+ {
+ name: "Opencode spoofing UA with arbitrary user_id",
+ userAgent: "claude-cli/2.1.92",
+ metadataUserID: "session_abc",
want: false,
},
}
@@ -378,16 +401,27 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
- // system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}]
+ // system 应为 array 格式,对齐真实 Claude Code CLI 的 2-block 形态:
+ // [0] billing attribution block (x-anthropic-billing-header: cc_version=...;)
+ // [1] Claude Code prompt block (带 cache_control)
systemArr, ok := parsed["system"].([]any)
require.True(t, ok, "system should be an array, got %T", parsed["system"])
- require.Len(t, systemArr, 1, "system array should have exactly 1 block")
- systemBlock, ok := systemArr[0].(map[string]any)
+ require.Len(t, systemArr, 2, "system array should have exactly 2 blocks (billing + cc prompt)")
+
+ billingBlock, ok := systemArr[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "text", billingBlock["type"])
+ require.Contains(t, billingBlock["text"], "x-anthropic-billing-header:")
+ require.Contains(t, billingBlock["text"], "cc_version=")
+ require.Contains(t, billingBlock["text"], "cc_entrypoint=cli")
+ require.Contains(t, billingBlock["text"], "cch=00000")
+
+ systemBlock, ok := systemArr[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", systemBlock["type"])
require.Equal(t, tt.wantSystemText, systemBlock["text"])
cc, ok := systemBlock["cache_control"].(map[string]any)
- require.True(t, ok, "system block should have cache_control")
+ require.True(t, ok, "cc prompt block should have cache_control")
require.Equal(t, "ephemeral", cc["type"])
// 检查 messages
diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go
index 55cb2c84..498336a4 100644
--- a/backend/internal/service/gateway_request.go
+++ b/backend/internal/service/gateway_request.go
@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string {
return nil
}
switch value {
- case "low", "medium", "high", "max":
+ case "low", "medium", "high", "xhigh", "max":
return &value
default:
return nil
diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go
index d262456d..40bd1186 100644
--- a/backend/internal/service/gateway_request_test.go
+++ b/backend/internal/service/gateway_request_test.go
@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) {
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
wantEffort: "max",
},
+ {
+ name: "output_config.effort xhigh",
+ body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`,
+ wantEffort: "xhigh",
+ },
{
name: "output_config without effort",
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) {
{"LOW", strPtr("low")},
{"Max", strPtr("max")},
{" medium ", strPtr("medium")},
+ {"xhigh", strPtr("xhigh")},
+ {"XHIGH", strPtr("xhigh")},
{"", nil},
{"unknown", nil},
- {"xhigh", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 4b4fc0bf..074013c3 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -11,6 +11,7 @@ import (
"io"
"log/slog"
mathrand "math/rand"
+ "net"
"net/http"
"net/url"
"os"
@@ -20,6 +21,7 @@ import (
"strconv"
"strings"
"sync/atomic"
+ "syscall"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -60,6 +62,11 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
+const (
+ cacheTTLTarget5m = "5m"
+ cacheTTLTarget1h = "1h"
+)
+
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
@@ -119,7 +126,7 @@ func openAIStreamEventIsTerminal(data string) bool {
return true
}
switch gjson.Get(trimmed, "type").String() {
- case "response.completed", "response.done", "response.failed":
+ case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return true
default:
return false
@@ -329,7 +336,7 @@ func isClaudeCodeCredentialScopeError(msg string) bool {
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
- claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
+ claudeCliUserAgentRe = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
@@ -435,26 +442,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
-// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
-// 或请求的模型处于限流状态时,返回 true。
-// 这确保后续请求不会继续使用不可用的账号。
+// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等),
+// 额外检查模型级限流。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
-// Returns true when account status is error/disabled, schedulable is false,
-// within temporary unschedulable period, or the requested model is rate-limited.
-// This ensures subsequent requests won't continue using unavailable accounts.
+// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting.
func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil {
return false
}
- if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
+ if !account.IsSchedulable() {
return true
}
- if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
- return true
- }
- // 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true
}
@@ -659,15 +659,31 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
- if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" {
+ uid := ParseMetadataUserID(parsed.MetadataUserID)
+ if uid != nil && uid.SessionID != "" {
+ slog.Info("sticky.hash_source",
+ "source", "metadata_user_id",
+ "session_id", uid.SessionID,
+ "device_id", uid.DeviceID,
+ "is_new_format", uid.IsNewFormat,
+ )
return uid.SessionID
}
+ slog.Info("sticky.hash_metadata_parse_failed",
+ "metadata_user_id", parsed.MetadataUserID,
+ "parsed_nil", uid == nil,
+ )
}
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
- return s.hashContent(cacheableContent)
+ hash := s.hashContent(cacheableContent)
+ slog.Info("sticky.hash_source",
+ "source", "cacheable_content",
+ "hash", hash,
+ )
+ return hash
}
// 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串
@@ -707,7 +723,13 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
}
}
if combined.Len() > 0 {
- return s.hashContent(combined.String())
+ hash := s.hashContent(combined.String())
+ slog.Info("sticky.hash_source",
+ "source", "message_content_fallback",
+ "hash", hash,
+ "content_len", combined.Len(),
+ )
+ return hash
}
return ""
@@ -857,6 +879,7 @@ func (s *GatewayService) hashContent(content string) string {
type anthropicCacheControlPayload struct {
Type string `json:"type"`
+ TTL string `json:"ttl,omitempty"`
}
type anthropicSystemTextBlockPayload struct {
@@ -905,7 +928,10 @@ func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]b
Text: text,
}
if includeCacheControl {
- block.CacheControl = &anthropicCacheControlPayload{Type: "ephemeral"}
+ block.CacheControl = &anthropicCacheControlPayload{
+ Type: "ephemeral",
+ TTL: claude.DefaultCacheControlTTL,
+ }
}
return json.Marshal(block)
}
@@ -1081,19 +1107,52 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
- if gjson.GetBytes(out, "temperature").Exists() {
- if next, ok := deleteJSONPathBytes(out, "temperature"); ok {
+ // temperature:真实 Claude Code CLI 总是发送 temperature(默认 1,客户端可覆盖)。
+ // 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。
+ // 策略:客户端传了什么就透传;没传则补默认 1。
+ if !gjson.GetBytes(out, "temperature").Exists() {
+ if next, ok := setJSONValueBytes(out, "temperature", 1); ok {
out = next
modified = true
}
}
- if gjson.GetBytes(out, "tool_choice").Exists() {
- if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
+
+ // max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。
+ if !gjson.GetBytes(out, "max_tokens").Exists() {
+ if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok {
out = next
modified = true
}
}
+ // context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动
+ // 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。
+ // 客户端显式传了就透传;否则按 CLI 行为补齐。
+ if !gjson.GetBytes(out, "context_management").Exists() {
+ thinkingType := gjson.GetBytes(out, "thinking.type").String()
+ if thinkingType == "enabled" || thinkingType == "adaptive" {
+ const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}`
+ if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok {
+ out = next
+ modified = true
+ }
+ }
+ }
+
+ // tool_choice:与 Parrot 对齐,不再无条件删除。
+ // - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由
+ // applyToolNameRewriteToBody 同步映射为假名
+ // - 其他形态(auto/any/none)原样透传
+ // 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除
+ if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 {
+ if gjson.GetBytes(out, "tool_choice").Exists() {
+ if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok {
+ out = next
+ modified = true
+ }
+ }
+ }
+
if !modified {
return body, modelID
}
@@ -1135,6 +1194,135 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
+// applyClaudeCodeOAuthMimicryToBody 将"非 Claude Code 客户端 + Claude OAuth 账号"
+// 路径上原本只在 /v1/messages 里做的完整伪装应用到任意 body 上。
+//
+// 这是 /v1/messages 主路径上 rewriteSystemForNonClaudeCode +
+// normalizeClaudeOAuthRequestBody 流程的通用版,供 OpenAI 协议兼容层
+// (ForwardAsChatCompletions / ForwardAsResponses) 复用。
+//
+// 未抽离之前,OpenAI 协议兼容层仅做 injectClaudeCodePrompt(前置追加),
+// 而仓内 /v1/messages 路径自己的注释明确说过"仅前置追加无法通过 Anthropic
+// 第三方检测";那条注释就是本函数存在的根因。
+//
+// 参数:
+// - ctx / c:用于读取指纹和 gateway settings;c 可为 nil(如 count_tokens)。
+// - account:必须是 OAuth 账号,且调用方已判断不是 Claude Code 客户端。
+// - body:已经 marshal 成 Anthropic /v1/messages 格式的请求体。
+// - systemRaw:body 中原始 system 字段(用于判断是否需要 rewrite)。
+// - model:最终会发给上游的模型 ID(用于 haiku 旁路 + metadata 版本选择)。
+//
+// 返回:改写后的 body。即使中间任何一步失败,也会退化成原 body(不会 panic)。
+func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ systemRaw any,
+ model string,
+) []byte {
+ if account == nil || !account.IsOAuth() || len(body) == 0 {
+ return body
+ }
+
+ systemRewritten := false
+ if !strings.Contains(strings.ToLower(model), "haiku") {
+ body = rewriteSystemForNonClaudeCode(body, systemRaw)
+ systemRewritten = true
+ }
+
+ normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
+
+ if s.identityService != nil && c != nil && c.Request != nil {
+ if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil {
+ mimicMPT := false
+ if s.settingService != nil {
+ _, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
+ }
+ if !mimicMPT {
+ if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" {
+ normalizeOpts.injectMetadata = true
+ normalizeOpts.metadataUserID = uid
+ }
+ }
+ }
+ }
+
+ body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)
+
+ // Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
+ // 对齐 Parrot transform_request 里剩余的字段级改写。三步顺序有语义约束:
+ // 1) strip:先清除客户端的 messages[*].cache_control(多轮稳定性)
+ // 2) breakpoints:再注入 2 个断点(最后一条 + 倒数第二个 user turn)
+ // 3) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1]
+ // 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。
+ body = stripMessageCacheControl(body)
+ body = addMessageCacheBreakpoints(body)
+
+ if rw := buildToolNameRewriteFromBody(body); rw != nil {
+ body = applyToolNameRewriteToBody(body, rw)
+ if c != nil {
+ c.Set(toolNameRewriteKey, rw)
+ }
+ } else {
+ body = applyToolsLastCacheBreakpoint(body)
+ }
+
+ return body
+}
+
+// buildOAuthMetadataUserIDFromBody 是 buildOAuthMetadataUserID 的变体,
+// 适用于调用方手上没有 ParsedRequest 的场景(如 OpenAI 协议兼容层)。
+//
+// 与 buildOAuthMetadataUserID 的唯一区别:
+// - session hash 从 body 本体按同样规则重算,而不是读取 ParsedRequest 缓存值。
+// - 如果 body 里已经存在 metadata.user_id,则返回空(由 ensureClaudeOAuthMetadataUserID
+// 自行决定是否覆盖)。
+func (s *GatewayService) buildOAuthMetadataUserIDFromBody(
+ ctx context.Context,
+ account *Account,
+ fp *Fingerprint,
+ body []byte,
+) string {
+ _ = ctx
+ if account == nil {
+ return ""
+ }
+ if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" {
+ return ""
+ }
+
+ userID := strings.TrimSpace(account.GetClaudeUserID())
+ if userID == "" && fp != nil {
+ userID = fp.ClientID
+ }
+ if userID == "" {
+ userID = generateClientID()
+ }
+
+ sessionID := uuid.NewString()
+ if hash := hashBodyForSessionSeed(body); hash != "" {
+ sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash))
+ }
+
+ var uaVersion string
+ if fp != nil {
+ uaVersion = ExtractCLIVersion(fp.UserAgent)
+ }
+ accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
+ return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
+}
+
+// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。
+// 复用 SHA-256 + 截断,与 generateSessionUUID 的输入格式对齐。
+func hashBodyForSessionSeed(body []byte) string {
+ if len(body) == 0 {
+ return ""
+ }
+ sum := sha256.Sum256(body)
+ return fmt.Sprintf("%x", sum[:16])
+}
+
// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
func GenerateSessionUUID(seed string) string {
return generateSessionUUID(seed)
@@ -1245,14 +1433,29 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
var stickyAccountID int64
+ var stickySource string
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
+ stickySource = "prefetch"
} else if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
+ stickySource = "cache"
}
}
+ // [DEBUG-STICKY] 调度器入口日志
+ slog.Info("sticky.scheduler_entry",
+ "group_id", derefGroupID(groupID),
+ "session_hash", shortSessionHash(sessionHash),
+ "sticky_account_id", stickyAccountID,
+ "sticky_source", stickySource,
+ "model", requestedModel,
+ "load_batch", cfg.LoadBatchEnabled,
+ "has_concurrency_svc", s.concurrencyService != nil,
+ "excluded_count", len(excludedIDs),
+ )
+
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
@@ -1428,6 +1631,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if len(routingCandidates) > 0 {
// 1.5. 在路由账号范围内检查粘性会话
if sessionHash != "" && stickyAccountID > 0 {
+ slog.Debug("sticky.layer1_5_checking",
+ "sticky_account_id", stickyAccountID,
+ "in_routing_list", containsInt64(routingAccountIDs, stickyAccountID),
+ "is_excluded", isExcluded(stickyAccountID),
+ "in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(),
+ "session", shortSessionHash(sessionHash),
+ )
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
@@ -1451,6 +1661,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择
} else {
+ slog.Debug("sticky.layer1_5_hit",
+ "account_id", stickyAccountID,
+ "session", shortSessionHash(sessionHash),
+ "result", "slot_acquired",
+ )
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
@@ -1601,27 +1816,65 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 检查账户是否需要清理粘性会话绑定
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
+ slog.Debug("sticky.layer1_5_no_routing_clear",
+ "account_id", accountID,
+ "reason", "should_clear_sticky_session",
+ "session", shortSessionHash(sessionHash),
+ )
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) &&
- s.isAccountAllowedForPlatform(account, platform, useMixed) &&
- (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
- s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
- s.isAccountSchedulableForQuota(account) &&
- s.isAccountSchedulableForWindowCost(ctx, account, true) &&
- s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
+ // 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的
+ // accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存
+ // 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。
+ platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed)
+ modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)
+ modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel)
+ quotaOK := s.isAccountSchedulableForQuota(account)
+ windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true)
+ rpmOK := s.isAccountSchedulableForRPM(ctx, account, true)
+ schedulable := s.isAccountSchedulableForSelection(account)
+
+ slog.Debug("sticky.layer1_5_no_routing_checks",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ "clear_sticky", clearSticky,
+ "schedulable", schedulable,
+ "platform_ok", platformOK,
+ "model_supported", modelSupported,
+ "model_schedulable", modelSchedulable,
+ "quota_ok", quotaOK,
+ "window_cost_ok", windowCostOK,
+ "rpm_ok", rpmOK,
+ )
+
+ if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
+ slog.Debug("sticky.layer1_5_no_routing_miss",
+ "account_id", accountID,
+ "reason", "session_limit",
+ "session", shortSessionHash(sessionHash),
+ )
} else {
+ slog.Debug("sticky.layer1_5_no_routing_hit",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ "result", "slot_acquired",
+ )
if s.cache != nil {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
}
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}
+ } else {
+ slog.Debug("sticky.layer1_5_no_routing_slot_busy",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ )
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
@@ -1630,6 +1883,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
} else {
+ slog.Debug("sticky.layer1_5_no_routing_hit",
+ "account_id", accountID,
+ "session", shortSessionHash(sessionHash),
+ "result", "wait_plan",
+ )
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
@@ -1638,12 +1896,42 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
})
}
}
+ } else if !clearSticky {
+ slog.Debug("sticky.layer1_5_no_routing_miss",
+ "account_id", accountID,
+ "reason", "gate_check_failed",
+ "session", shortSessionHash(sessionHash),
+ )
}
+ } else {
+ slog.Debug("sticky.layer1_5_no_routing_miss",
+ "account_id", accountID,
+ "reason", "account_not_in_map",
+ "session", shortSessionHash(sessionHash),
+ )
}
}
+ } else if len(routingAccountIDs) == 0 && sessionHash != "" {
+ slog.Debug("sticky.layer1_5_no_routing_skip",
+ "sticky_account_id", stickyAccountID,
+ "is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(),
+ "session", shortSessionHash(sessionHash),
+ "reason", func() string {
+ if stickyAccountID == 0 {
+ return "no_sticky_binding"
+ }
+ return "sticky_account_excluded"
+ }(),
+ )
}
// ============ Layer 2: 负载感知选择 ============
+ slog.Debug("sticky.layer2_fallback",
+ "session", shortSessionHash(sessionHash),
+ "sticky_account_id", stickyAccountID,
+ "reason", "sticky_not_used_falling_back_to_load_balance",
+ "total_accounts", len(accounts),
+ )
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
@@ -3438,7 +3726,11 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
- requestedModel = claude.NormalizeModelID(requestedModel)
+ if account.Type == AccountTypeServiceAccount {
+ requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel))
+ } else {
+ requestedModel = claude.NormalizeModelID(requestedModel)
+ }
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
@@ -3458,6 +3750,18 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
return apiKey, "apikey", nil
case AccountTypeBedrock:
return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理
+ case AccountTypeServiceAccount:
+ if account.Platform != PlatformAnthropic {
+ return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
+ }
+ if s.claudeTokenProvider == nil {
+ return "", "", errors.New("claude token provider not configured")
+ }
+ accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return "", "", err
+ }
+ return accessToken, "service_account", nil
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -3550,23 +3854,19 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
}
}
-// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
-// 简化判断:User-Agent 匹配 + metadata.user_id 存在
+// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端。
+// 判定条件:
+// 1. User-Agent 匹配 claude-cli/X.Y.Z(大小写不敏感)
+// 2. metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式)
+//
+// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA
+// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID
+// 验证格式才能确认是真正的 Claude Code 客户端。
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
- if metadataUserID == "" {
+ if !claudeCliUserAgentRe.MatchString(userAgent) {
return false
}
- return claudeCliUserAgentRe.MatchString(userAgent)
-}
-
-func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
- if IsClaudeCodeClient(ctx) {
- return true
- }
- if parsed == nil || c == nil {
- return false
- }
- return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
+ return ParseMetadataUserID(metadataUserID) != nil
}
// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil),
@@ -3745,17 +4045,20 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
originalSystemText = strings.Join(parts, "\n\n")
}
- // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
- // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
- // 使用 string 格式会被 Anthropic 检测为第三方应用。
- claudeCodeSystemBlock := []map[string]any{
- {
- "type": "text",
- "text": claudeCodeSystemPrompt,
- "cache_control": map[string]string{"type": "ephemeral"},
- },
+ // 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态:
+ // [0] billing attribution block(cc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;)
+ // [1] "You are Claude Code..." prompt block(带 cache_control 作为稳定缓存断点)
+ //
+ // billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的
+ // signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload
+ // 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。
+ billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion)
+ ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true)
+ if billingErr != nil || ccErr != nil {
+ logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr)
+ return body
}
- out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
+ out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock}))
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
@@ -3928,6 +4231,87 @@ func enforceCacheControlLimit(body []byte) []byte {
return body
}
+// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。
+// 仅修改已经存在的 cache_control,不新增缓存断点。
+func injectAnthropicCacheControlTTL1h(body []byte) []byte {
+ return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
+}
+
+func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
+ if len(body) == 0 || ttl == "" {
+ return body
+ }
+ out := body
+ var paths []string
+ addPath := func(path string, value gjson.Result) {
+ cc := value.Get("cache_control")
+ if !cc.Exists() || cc.Get("type").String() != "ephemeral" {
+ return
+ }
+ if cc.Get("ttl").String() == ttl {
+ return
+ }
+ paths = append(paths, path+".cache_control.ttl")
+ }
+
+ if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl {
+ paths = append(paths, "cache_control.ttl")
+ }
+
+ system := gjson.GetBytes(body, "system")
+ if system.IsArray() {
+ idx := -1
+ system.ForEach(func(_, block gjson.Result) bool {
+ idx++
+ addPath(fmt.Sprintf("system.%d", idx), block)
+ return true
+ })
+ }
+
+ messages := gjson.GetBytes(body, "messages")
+ if messages.IsArray() {
+ msgIdx := -1
+ messages.ForEach(func(_, msg gjson.Result) bool {
+ msgIdx++
+ content := msg.Get("content")
+ if !content.IsArray() {
+ return true
+ }
+ contentIdx := -1
+ content.ForEach(func(_, block gjson.Result) bool {
+ contentIdx++
+ addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block)
+ return true
+ })
+ return true
+ })
+ }
+
+ tools := gjson.GetBytes(body, "tools")
+ if tools.IsArray() {
+ idx := -1
+ tools.ForEach(func(_, tool gjson.Result) bool {
+ idx++
+ addPath(fmt.Sprintf("tools.%d", idx), tool)
+ return true
+ })
+ }
+
+ for _, path := range paths {
+ if next, err := sjson.SetBytes(out, path, ttl); err == nil {
+ out = next
+ }
+ }
+ return out
+}
+
+func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
+ if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil {
+ return false
+ }
+ return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
+}
+
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
@@ -3992,15 +4376,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
})
}
- isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
+ // Claude Code 客户端判定:UA 匹配 claude-cli/* 且携带 metadata.user_id。
+ // 真正的 Claude Code 客户端自带完整的 system prompt、cache_control 断点和 header,
+ // 不需要代理做任何 body 级别的 mimicry;强行替换反而会破坏客户端的缓存策略
+ // (长 system prompt 被替换为 ~45 tokens 的短 prompt,低于 Anthropic 1024 token
+ // 最低缓存门槛,导致系统级缓存失效)。
+ //
+ // 对于非 Claude Code 的第三方客户端(opencode 等),仍然走完整 mimicry。
+ isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
- // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
- // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
+ // 与 Parrot 对齐:OAuth 账号无条件重写 system(即使客户端已发了 Claude Code
+ // 风格的 system prompt)。原因:第三方工具(opencode 等)会发 "You are Claude
+ // Code..." system prompt 但缺少 billing attribution block,导致 Anthropic
+ // 检测到"有 CC prompt 但无 billing block"的不一致而判为 third-party。
+ // Parrot 的 transform_request 从不检查客户端 system 内容,直接覆盖。
systemRewritten := false
- if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
- !systemIncludesClaudeCodePrompt(parsed.System) {
+ if !strings.Contains(strings.ToLower(reqModel), "haiku") {
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
}
@@ -4024,6 +4417,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
+
+ // D/E/F: messages cache 策略 + 工具名混淆 + tools[-1] 断点
+ // 与 forward_as_chat_completions / forward_as_responses 路径对齐,
+ // 保证原生 /v1/messages 路径也经过完整的 Parrot 字段级改写。
+ body = stripMessageCacheControl(body)
+ body = addMessageCacheBreakpoints(body)
+ if rw := buildToolNameRewriteFromBody(body); rw != nil {
+ body = applyToolNameRewriteToBody(body, rw)
+ c.Set(toolNameRewriteKey, rw)
+ } else {
+ body = applyToolsLastCacheBreakpoint(body)
+ }
}
// 强制执行 cache_control 块数量限制(最多 4 个)
@@ -4040,6 +4445,18 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
mappingSource = "account"
}
}
+ if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ if candidate, matched := account.ResolveMappedModel(reqModel); matched {
+ mappedModel = candidate
+ mappingSource = "account"
+ } else {
+ normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
+ if normalized != reqModel {
+ mappedModel = normalized
+ mappingSource = "vertex"
+ }
+ }
+ }
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
normalized := claude.NormalizeModelID(reqModel)
if normalized != reqModel {
@@ -4054,6 +4471,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
+ if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
+ body = injectAnthropicCacheControlTTL1h(body)
+ }
+
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -4962,7 +5383,8 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
}
if !clientDisconnected {
- if _, err := io.WriteString(w, line); err != nil {
+ restored := string(reverseToolNamesIfPresent(c, []byte(line)))
+ if _, err := io.WriteString(w, restored); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else if _, err := io.WriteString(w, "\n"); err != nil {
@@ -5132,6 +5554,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
if contentType == "" {
contentType = "application/json"
}
+ body = reverseToolNamesIfPresent(c, body)
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
@@ -5507,6 +5930,10 @@ func (s *GatewayService) handleBedrockNonStreamingResponse(
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
+ if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
+ return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
+ }
+
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -5587,13 +6014,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
setHeaderRaw(req.Header, "x-api-key", token)
}
- // 白名单透传headers(恢复真实 wire casing)
- for key, values := range clientHeaders {
- lowerKey := strings.ToLower(key)
- if allowedHeaders[lowerKey] {
- wireKey := resolveWireCasing(key)
- for _, v := range values {
- addHeaderRaw(req.Header, wireKey, v)
+ // 白名单透传 headers
+ // OAuth mimicry 路径:跳过客户端 header 透传,与 Parrot 对齐。
+ // Parrot 的 build_upstream_headers 只发 9 个精确 header,不透传任何客户端 header。
+ // 透传客户端 header 会引入不一致的 x-stainless-* / anthropic-beta / user-agent /
+ // x-claude-code-session-id 等值,和我们注入的伪装 header 冲突,被 Anthropic 判 third-party。
+ if tokenType != "oauth" || !mimicClaudeCode {
+ for key, values := range clientHeaders {
+ lowerKey := strings.ToLower(key)
+ if allowedHeaders[lowerKey] {
+ wireKey := resolveWireCasing(key)
+ for _, v := range values {
+ addHeaderRaw(req.Header, wireKey, v)
+ }
}
}
}
@@ -5634,7 +6067,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
- requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
+ requiredBetas = claude.FullClaudeCodeMimicryBetas()
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
@@ -5687,6 +6120,60 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
return req, nil
}
+func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ token string,
+ modelID string,
+ reqStream bool,
+) (*http.Request, error) {
+ vertexBody, err := buildVertexAnthropicRequestBody(body)
+ if err != nil {
+ return nil, err
+ }
+ setOpsUpstreamRequestBody(c, vertexBody)
+ fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
+ if err != nil {
+ return nil, err
+ }
+
+ if c != nil && c.Request != nil {
+ for key, values := range c.Request.Header {
+ lowerKey := strings.ToLower(strings.TrimSpace(key))
+ if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
+ continue
+ }
+ wireKey := resolveWireCasing(key)
+ for _, v := range values {
+ addHeaderRaw(req.Header, wireKey, v)
+ }
+ }
+ }
+
+ req.Header.Del("authorization")
+ req.Header.Del("x-api-key")
+ req.Header.Del("x-goog-api-key")
+ req.Header.Del("cookie")
+ req.Header.Del("anthropic-version")
+ setHeaderRaw(req.Header, "authorization", "Bearer "+token)
+ setHeaderRaw(req.Header, "content-type", "application/json")
+
+ s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
+ "url": req.URL.String(),
+ "token_type": "service_account",
+ "model": modelID,
+ "stream": strconv.FormatBool(reqStream),
+ })
+
+ return req, nil
+}
+
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号,需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
@@ -6106,6 +6593,11 @@ func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if isStream {
setHeaderRaw(req.Header, "x-stainless-helper-method", "stream")
}
+ // Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。
+ // 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。
+ if getHeaderRaw(req.Header, "x-client-request-id") == "" {
+ setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString())
+ }
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -6242,6 +6734,49 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false
}
+// sanitizeStreamError 返回不含网络地址的客户端可见错误描述。
+// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
+// 服务器地址(例如 "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection
+// reset by peer")。该函数只保留可识别的错误类别,原始 err 仍在调用点写入日志。
+func sanitizeStreamError(err error) string {
+ if err == nil {
+ return ""
+ }
+ switch {
+ case errors.Is(err, io.ErrUnexpectedEOF):
+ return "unexpected EOF"
+ case errors.Is(err, io.EOF):
+ return "EOF"
+ case errors.Is(err, context.Canceled):
+ return "canceled"
+ case errors.Is(err, context.DeadlineExceeded):
+ return "deadline exceeded"
+ case errors.Is(err, syscall.ECONNRESET):
+ return "connection reset by peer"
+ case errors.Is(err, syscall.ECONNABORTED):
+ return "connection aborted"
+ case errors.Is(err, syscall.ETIMEDOUT):
+ return "connection timed out"
+ case errors.Is(err, syscall.EPIPE):
+ return "broken pipe"
+ case errors.Is(err, syscall.ECONNREFUSED):
+ return "connection refused"
+ }
+ var netErr *net.OpError
+ if errors.As(err, &netErr) {
+ if netErr.Timeout() {
+ if netErr.Op != "" {
+ return netErr.Op + " timeout"
+ }
+ return "i/o timeout"
+ }
+ if netErr.Op != "" {
+ return netErr.Op + " network error"
+ }
+ }
+ return "upstream connection error"
+}
+
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string {
@@ -6679,14 +7214,31 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
lastDataAt := time.Now()
- // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
+ // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。
+ // 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":,"message":}}
+ // 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案,
+ // 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。
errorEventSent := false
- sendErrorEvent := func(reason string) {
+ sendErrorEvent := func(reason, message string) {
if errorEventSent {
return
}
errorEventSent = true
- _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
+ if message == "" {
+ message = reason
+ }
+ body, err := json.Marshal(map[string]any{
+ "type": "error",
+ "error": map[string]string{
+ "type": reason,
+ "message": message,
+ },
+ })
+ if err != nil {
+ // json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback
+ body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
+ }
+ _, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
flusher.Flush()
}
@@ -6763,9 +7315,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
- // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() {
- overrideTarget := account.GetCacheTTLOverrideTarget()
+ // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。
+ // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
+ if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
@@ -6846,10 +7398,32 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
- sendErrorEvent("response_too_large")
+ sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
- sendErrorEvent("stream_read_error")
+ // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY):
+ // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。
+ // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。
+ // 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址,
+ // 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err
+ // 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。
+ disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
+ if !c.Writer.Written() {
+ logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
+ body, _ := json.Marshal(map[string]any{
+ "type": "error",
+ "error": map[string]string{
+ "type": "upstream_disconnected",
+ "message": disconnectMsg,
+ },
+ })
+ return nil, &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ ResponseBody: body,
+ RetryableOnSameAccount: true,
+ }
+ }
+ sendErrorEvent("stream_read_error", disconnectMsg)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
@@ -6871,7 +7445,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
for _, block := range outputBlocks {
if !clientDisconnected {
- if _, werr := fmt.Fprint(w, block); werr != nil {
+ restored := reverseToolNamesIfPresent(c, []byte(block))
+ if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
break
@@ -6907,7 +7482,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
- sendErrorEvent("stream_timeout")
+ sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
@@ -7149,6 +7724,19 @@ func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
return true
}
+func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
+ if account == nil {
+ return "", false
+ }
+ if account.IsCacheTTLOverrideEnabled() {
+ return account.GetCacheTTLOverrideTarget(), true
+ }
+ if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
+ return cacheTTLTarget5m, true
+ }
+ return "", false
+}
+
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -7185,9 +7773,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
- // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
- if account.IsCacheTTLOverrideEnabled() {
- overrideTarget := account.GetCacheTTLOverrideTarget()
+ // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。
+ // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
+ if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
@@ -7213,6 +7801,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
+ body = reverseToolNamesIfPresent(c, body)
+
// 写入响应
c.Data(resp.StatusCode, contentType, body)
@@ -7317,8 +7907,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
cost := p.Cost
if p.IsSubscriptionBill {
- if cost.TotalCost > 0 {
- if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
+ // Subscription usage tracked by ActualCost so group rate multiplier
+ // consumes the quota at the expected speed.
+ if cost.ActualCost > 0 {
+ if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
}
}
@@ -7417,9 +8009,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
}
}
+ // Record subscription / balance cost using ActualCost so the group (and any
+ // user-specific) rate multiplier consumes subscription quota at the expected
+ // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
+ // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID
- cmd.SubscriptionCost = p.Cost.TotalCost
+ cmd.SubscriptionCost = p.Cost.ActualCost
} else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost
}
@@ -7478,8 +8074,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
}
if p.IsSubscriptionBill {
- if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
- deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
+ if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
+ deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
}
} else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
@@ -7747,10 +8343,11 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0
}
- // Cache TTL Override: 确保计费时 token 分类与账号设置一致
+ // Cache TTL Override: 确保计费时 token 分类与账号设置一致。
+ // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。
cacheTTLOverridden := false
- if account.IsCacheTTLOverrideEnabled() {
- applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
+ if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
+ applyCacheTTLOverride(&result.Usage, overrideTarget)
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
@@ -8195,12 +8792,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// Pre-filter: strip empty text blocks to prevent upstream 400.
body = StripEmptyTextBlocks(body)
- isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
- shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
+ isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
+ shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT
if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
+
+ body = stripMessageCacheControl(body)
+ body = addMessageCacheBreakpoints(body)
+ if rw := buildToolNameRewriteFromBody(body); rw != nil {
+ body = applyToolNameRewriteToBody(body, rw)
+ } else {
+ body = applyToolsLastCacheBreakpoint(body)
+ }
}
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
@@ -8624,7 +9229,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
- requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
+ requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting)
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet))
} else {
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go
new file mode 100644
index 00000000..42a81035
--- /dev/null
+++ b/backend/internal/service/gateway_service_subscription_billing_test.go
@@ -0,0 +1,85 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+)
+
+// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
+// that subscription-mode billing honours the group (and any user-specific) rate
+// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
+// RateMultiplier), not raw TotalCost.
+func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
+ t.Parallel()
+
+ groupID := int64(7)
+ subID := int64(42)
+
+ tests := []struct {
+ name string
+ totalCost float64
+ actualCost float64
+ isSubscription bool
+ wantSub float64
+ wantBalance float64
+ }{
+ {
+ name: "subscription with 2x multiplier consumes 2x quota",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: true,
+ wantSub: 2.0,
+ wantBalance: 0,
+ },
+ {
+ name: "subscription with 0.5x multiplier consumes 0.5x quota",
+ totalCost: 1.0,
+ actualCost: 0.5,
+ isSubscription: true,
+ wantSub: 0.5,
+ wantBalance: 0,
+ },
+ {
+ name: "free subscription (multiplier 0) consumes no quota",
+ totalCost: 1.0,
+ actualCost: 0,
+ isSubscription: true,
+ wantSub: 0,
+ wantBalance: 0,
+ },
+ {
+ name: "balance billing keeps using ActualCost (regression)",
+ totalCost: 1.0,
+ actualCost: 2.0,
+ isSubscription: false,
+ wantSub: 0,
+ wantBalance: 2.0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ p := &postUsageBillingParams{
+ Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
+ User: &User{ID: 1},
+ APIKey: &APIKey{ID: 2, GroupID: &groupID},
+ Account: &Account{ID: 3},
+ Subscription: &UserSubscription{ID: subID},
+ IsSubscriptionBill: tt.isSubscription,
+ }
+
+ cmd := buildUsageBillingCommand("req-1", nil, p)
+ if cmd == nil {
+ t.Fatal("buildUsageBillingCommand returned nil")
+ }
+ if cmd.SubscriptionCost != tt.wantSub {
+ t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
+ }
+ if cmd.BalanceCost != tt.wantBalance {
+ t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go
index b1584827..ef09a882 100644
--- a/backend/internal/service/gateway_streaming_test.go
+++ b/backend/internal/service/gateway_streaming_test.go
@@ -4,9 +4,12 @@ package service
import (
"context"
+ "errors"
"io"
+ "net"
"net/http"
"net/http/httptest"
+ "syscall"
"testing"
"time"
@@ -218,3 +221,175 @@ func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) {
body := rec.Body.String()
require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件")
}
+
+// 上游中途读错误(如 HTTP/2 GOAWAY 触发的 unexpected EOF)发生在向客户端写入任何字节前:
+// 网关应返回 *UpstreamFailoverError 触发账号 failover/重试,而不是把错误事件直接发给客户端。
+func TestHandleStreamingResponse_StreamReadErrorBeforeOutput_TriggersFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ svc := newMinimalGatewayService()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: &streamReadCloser{err: io.ErrUnexpectedEOF},
+ }
+
+ result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+
+ require.Error(t, err)
+ require.Nil(t, result, "失败移交场景下不应返回 streamingResult")
+
+ var failoverErr *UpstreamFailoverError
+ require.True(t, errors.As(err, &failoverErr), "未输出过字节时 stream read error 必须包成 UpstreamFailoverError,期望: %v", err)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.True(t, failoverErr.RetryableOnSameAccount, "GOAWAY 类错误应允许同账号重试")
+
+ // ResponseBody 必须是 Anthropic 标准 error 格式:
+ // 1) ExtractUpstreamErrorMessage 能正确从 error.message 提取消息(被 handleFailoverExhausted / ops 日志依赖)
+ // 2) error.type 标记为 upstream_disconnected
+ extractedMsg := ExtractUpstreamErrorMessage(failoverErr.ResponseBody)
+ require.NotEmpty(t, extractedMsg, "ExtractUpstreamErrorMessage 必须从 ResponseBody 取到非空 message,否则 ops 日志会丢失诊断信息")
+ require.Contains(t, extractedMsg, "upstream stream disconnected")
+ require.Contains(t, string(failoverErr.ResponseBody), `"type":"error"`)
+ require.Contains(t, string(failoverErr.ResponseBody), `"upstream_disconnected"`)
+
+ // 客户端应收不到任何 stream_read_error 事件,由 handler 层根据 failover 结果再决定
+ require.NotContains(t, rec.Body.String(), "stream_read_error")
+}
+
+// 上游已经发送过事件(c.Writer 已写过字节)后再发生读错误:
+// SSE 协议无 resume,网关只能透传 stream_read_error 错误事件给客户端,不能 failover。
+func TestHandleStreamingResponse_StreamReadErrorAfterOutput_PassesThrough(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ svc := newMinimalGatewayService()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ // 第一次 Read 返回完整 SSE 事件让网关向 client 写入字节,第二次 Read 返回 EOF
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: &streamReadCloser{
+ payload: []byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n"),
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "stream read error", "已开始流后应透传普通 stream read error")
+ require.NotNil(t, result, "透传场景下应返回已收集的 streamingResult")
+
+ // 不应被错误地包成 failover error
+ var failoverErr *UpstreamFailoverError
+ require.False(t, errors.As(err, &failoverErr), "已经向客户端写过字节时不能再 failover")
+
+ // 客户端必须收到 Anthropic 标准格式的 SSE error 事件,error.type=stream_read_error,
+ // error.message 含具体根因(让 SDK 能解析、UI 能显示具体错误)
+ body := rec.Body.String()
+ require.Contains(t, body, "event: error\n", "必须按 Anthropic SSE 标准发送 error 事件帧")
+ require.Contains(t, body, `"type":"error"`, "data 必须含 type:error 顶层字段(Anthropic 标准)")
+ require.Contains(t, body, `"stream_read_error"`, "error.type 必须为 stream_read_error")
+ require.Contains(t, body, "upstream stream disconnected", "error.message 必须包含具体根因,Claude Code 等客户端才能显示有效错误文案")
+}
+
+// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游
+// 服务器地址。sanitizeStreamError 必须剥离这些信息,避免基础设施拓扑通过
+// failover ResponseBody 或 SSE error 帧返回给客户端。
+func TestSanitizeStreamError_StripsNetworkAddresses(t *testing.T) {
+ src, err := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
+ require.NoError(t, err)
+ dst, err := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
+ require.NoError(t, err)
+
+ raw := &net.OpError{
+ Op: "read",
+ Net: "tcp",
+ Source: src,
+ Addr: dst,
+ Err: syscall.ECONNRESET,
+ }
+
+ // 前置:原始 Error() 确实包含会泄露的字段(避免测试在 Go 行为变化时静默通过)
+ require.Contains(t, raw.Error(), "10.0.0.1")
+ require.Contains(t, raw.Error(), "52.1.2.3")
+
+ got := sanitizeStreamError(raw)
+ require.NotContains(t, got, "10.0.0.1", "不得泄露内部源 IP")
+ require.NotContains(t, got, "54321", "不得泄露源端口")
+ require.NotContains(t, got, "52.1.2.3", "不得泄露上游目标 IP")
+ require.NotContains(t, got, "443", "不得泄露上游端口")
+ require.Equal(t, "connection reset by peer", got)
+}
+
+func TestSanitizeStreamError_KnownErrors(t *testing.T) {
+ cases := []struct {
+ name string
+ err error
+ want string
+ }{
+ {"unexpected EOF", io.ErrUnexpectedEOF, "unexpected EOF"},
+ {"EOF", io.EOF, "EOF"},
+ {"context canceled", context.Canceled, "canceled"},
+ {"deadline exceeded", context.DeadlineExceeded, "deadline exceeded"},
+ {"ECONNRESET 直接", syscall.ECONNRESET, "connection reset by peer"},
+ {"EPIPE", syscall.EPIPE, "broken pipe"},
+ {"ETIMEDOUT", syscall.ETIMEDOUT, "connection timed out"},
+ {"未识别错误兜底", errors.New("weird internal error"), "upstream connection error"},
+ {"nil 返回空串", nil, ""},
+ }
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ require.Equal(t, tc.want, sanitizeStreamError(tc.err))
+ })
+ }
+}
+
+// failover ResponseBody 必须用 sanitize 过的消息,避免泄露给客户端 / 写入 ops 日志
+// 时携带内部地址信息。
+func TestHandleStreamingResponse_FailoverBodyDoesNotLeakAddresses(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ svc := newMinimalGatewayService()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ src, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:54321")
+ dst, _ := net.ResolveTCPAddr("tcp", "52.1.2.3:443")
+ netErr := &net.OpError{
+ Op: "read",
+ Net: "tcp",
+ Source: src,
+ Addr: dst,
+ Err: syscall.ECONNRESET,
+ }
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: &streamReadCloser{err: netErr},
+ }
+
+ _, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.True(t, errors.As(err, &failoverErr))
+
+ body := string(failoverErr.ResponseBody)
+ require.NotContains(t, body, "10.0.0.1", "failover ResponseBody 不得泄露内部源 IP")
+ require.NotContains(t, body, "54321")
+ require.NotContains(t, body, "52.1.2.3", "failover ResponseBody 不得泄露上游 IP")
+ require.NotContains(t, body, "443")
+ // 仍然包含可诊断的根因
+ require.Contains(t, body, "connection reset by peer")
+ require.Contains(t, body, "upstream stream disconnected")
+}
diff --git a/backend/internal/service/gateway_tool_rewrite.go b/backend/internal/service/gateway_tool_rewrite.go
new file mode 100644
index 00000000..c76cab62
--- /dev/null
+++ b/backend/internal/service/gateway_tool_rewrite.go
@@ -0,0 +1,313 @@
+package service
+
+import (
+ "fmt"
+ "hash/fnv"
+ "math/rand"
+ "sort"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// toolNameRewriteKey 是 gin.Context 上存 ToolNameRewrite 映射的 key。
+// 请求阶段写入,响应阶段读取,用于 bytes 级逆向还原假名 → 真名。
+const toolNameRewriteKey = "claude_tool_name_rewrite"
+
+// staticToolNameRewrites 是"静态前缀映射",与 Parrot src/transform/cc_mimicry.py
+// TOOL_NAME_REWRITES 完全一致。只有以这些前缀开头的工具会被重写。
+var staticToolNameRewrites = map[string]string{
+ "sessions_": "cc_sess_",
+ "session_": "cc_ses_",
+}
+
+// fakeToolNamePrefixes 是"动态映射"的前缀池,与 Parrot _FAKE_PREFIXES 一致。
+// 当 tools 数量 > dynamicToolMapThreshold 时随机选用其中前缀生成可读假名。
+var fakeToolNamePrefixes = []string{
+ "analyze_", "compute_", "fetch_", "generate_", "lookup_", "modify_",
+ "process_", "query_", "render_", "resolve_", "sync_", "update_",
+ "validate_", "convert_", "extract_", "manage_", "monitor_", "parse_",
+ "review_", "search_", "transform_", "handle_", "invoke_", "notify_",
+}
+
+// dynamicToolMapThreshold 与 Parrot 一致:tools 数量超过 5 才启用动态映射。
+// 少量工具不需要混淆(一般是 Claude Code 自己的核心工具 bash/edit/read 等)。
+const dynamicToolMapThreshold = 5
+
+// ToolNameRewrite 是单次请求内的工具名混淆映射。
+// - Forward: real → fake,请求阶段在 body 上应用。
+// - Reverse: fake → real,响应阶段对每个 chunk 做 bytes.Replace 还原。
+//
+// ReverseOrdered 是按假名长度倒序的 (fake, real) 列表,用于防止短假名是长假名的
+// 子串时 bytes.Replace 先被吃掉(对齐 Parrot _restore_tool_names_in_chunk 的
+// `sorted(..., key=lambda x: len(x[1]), reverse=True)`)。
+type ToolNameRewrite struct {
+ Forward map[string]string
+ Reverse map[string]string
+ ReverseOrdered [][2]string
+}
+
+// buildDynamicToolMap 构造 tools 的动态假名映射。
+//
+// 与 Parrot _build_dynamic_tool_map 语义等价:
+// - tools 数量 ≤ dynamicToolMapThreshold 时返回 nil(不做动态映射,走静态 fallback)
+// - 同一组 tool_names 在同进程内映射稳定(保证 cache 命中)
+//
+// Parrot 用 `random.Random(hash(tuple(tool_names)))` 作 seed + shuffle 前缀池;
+// Go 无法字节级复刻 Python hash,但"稳定性"和"前缀池打散"两个不变量都保留:
+// 用 fnv64a(strings.Join(names, "\x00")) 作 seed 喂 math/rand.New。
+// 字节级不同不影响上游判定(Anthropic 不会验证我们的随机种子算法)。
+func buildDynamicToolMap(toolNames []string) map[string]string {
+ if len(toolNames) <= dynamicToolMapThreshold {
+ return nil
+ }
+ h := fnv.New64a()
+ for i, n := range toolNames {
+ if i > 0 {
+ _, _ = h.Write([]byte{0})
+ }
+ _, _ = h.Write([]byte(n))
+ }
+ rng := rand.New(rand.NewSource(int64(h.Sum64())))
+
+ available := make([]string, len(fakeToolNamePrefixes))
+ copy(available, fakeToolNamePrefixes)
+ rng.Shuffle(len(available), func(i, j int) { available[i], available[j] = available[j], available[i] })
+
+ mapping := make(map[string]string, len(toolNames))
+ for i, name := range toolNames {
+ prefix := available[i%len(available)]
+ headLen := 3
+ if len(name) < 3 {
+ headLen = len(name)
+ }
+ fake := fmt.Sprintf("%s%s%02d", prefix, name[:headLen], i)
+ mapping[name] = fake
+ }
+ return mapping
+}
+
+// sanitizeToolName 把真名转成假名。
+// 与 Parrot _sanitize_tool_name 语义一致:动态映射优先,再走静态前缀映射。
+func sanitizeToolName(name string, dynamic map[string]string) string {
+ if dynamic != nil {
+ if fake, ok := dynamic[name]; ok {
+ return fake
+ }
+ }
+ for prefix, replacement := range staticToolNameRewrites {
+ if strings.HasPrefix(name, prefix) {
+ return replacement + name[len(prefix):]
+ }
+ }
+ return name
+}
+
+// shouldMimicToolName 指示某个 tool 是否需要重命名。
+// server tool(type != "" 且不是 "function" / "custom")是 Anthropic 协议语义的一部分,
+// 比如 "web_search_20250305" / "computer_20250124";误改会导致上游拒绝。
+func shouldMimicToolName(toolType string) bool {
+ if toolType == "" || toolType == "function" || toolType == "custom" {
+ return true
+ }
+ return false
+}
+
+// buildToolNameRewriteFromBody 扫描 body 的 tools[*].name,构造 ToolNameRewrite
+// 并返回它。若不需要混淆(tools 数量不足 + 没有匹配静态前缀的工具)返回 nil。
+//
+// 注意:只扫描,不改 body。真正的 body 改写在 applyToolNameRewriteToBody。
+func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.IsArray() {
+ return nil
+ }
+
+ mimicableNames := make([]string, 0)
+ toolsArr := tools.Array()
+ for _, t := range toolsArr {
+ if !shouldMimicToolName(t.Get("type").String()) {
+ continue
+ }
+ name := t.Get("name").String()
+ if name == "" {
+ continue
+ }
+ mimicableNames = append(mimicableNames, name)
+ }
+
+ dynamic := buildDynamicToolMap(mimicableNames)
+
+ rw := &ToolNameRewrite{
+ Forward: make(map[string]string),
+ Reverse: make(map[string]string),
+ }
+ for _, name := range mimicableNames {
+ fake := sanitizeToolName(name, dynamic)
+ if fake == name {
+ continue
+ }
+ rw.Forward[name] = fake
+ rw.Reverse[fake] = name
+ }
+ if len(rw.Forward) == 0 {
+ return nil
+ }
+
+ rw.ReverseOrdered = make([][2]string, 0, len(rw.Reverse))
+ for fake, real := range rw.Reverse {
+ rw.ReverseOrdered = append(rw.ReverseOrdered, [2]string{fake, real})
+ }
+ sort.SliceStable(rw.ReverseOrdered, func(i, j int) bool {
+ return len(rw.ReverseOrdered[i][0]) > len(rw.ReverseOrdered[j][0])
+ })
+
+ return rw
+}
+
+// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
+// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
+// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐,
+// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL)
+// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
+//
+// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致;
+// 响应侧 bytes.Replace 会连带还原它们。
+func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
+ if rw == nil || len(rw.Forward) == 0 {
+ body = applyToolsLastCacheBreakpoint(body)
+ return body
+ }
+
+ tools := gjson.GetBytes(body, "tools")
+ if tools.IsArray() {
+ idx := -1
+ tools.ForEach(func(_, t gjson.Result) bool {
+ idx++
+ if !shouldMimicToolName(t.Get("type").String()) {
+ return true
+ }
+ name := t.Get("name").String()
+ if name == "" {
+ return true
+ }
+ fake, ok := rw.Forward[name]
+ if !ok {
+ return true
+ }
+ if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.name", idx), fake); err == nil {
+ body = next
+ }
+ return true
+ })
+ }
+
+ if tc := gjson.GetBytes(body, "tool_choice"); tc.Exists() && tc.Get("type").String() == "tool" {
+ name := tc.Get("name").String()
+ if fake, ok := rw.Forward[name]; ok {
+ if next, err := sjson.SetBytes(body, "tool_choice.name", fake); err == nil {
+ body = next
+ }
+ }
+ }
+
+ body = applyToolsLastCacheBreakpoint(body)
+ return body
+}
+
+// applyToolsLastCacheBreakpoint 在 tools 数组最后一个工具上注入 cache_control
+// 断点,对齐 Parrot `tools[-1]["cache_control"] = {"type":"ephemeral","ttl":"1h"}`
+// 行为,但 ttl 按本仓规则:
+// - 客户端已为该 tool 显式设置 cache_control.ttl → 完全透传不覆盖
+// - 否则注入 {"type":"ephemeral","ttl": claude.DefaultCacheControlTTL}
+//
+// 纯副作用函数,tools 不存在或为空数组时 no-op。
+func applyToolsLastCacheBreakpoint(body []byte) []byte {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.IsArray() {
+ return body
+ }
+ arr := tools.Array()
+ if len(arr) == 0 {
+ return body
+ }
+ lastIdx := len(arr) - 1
+ existingCC := arr[lastIdx].Get("cache_control")
+
+ if existingCC.Exists() && existingCC.Get("ttl").String() != "" {
+ return body
+ }
+
+ if existingCC.Exists() {
+ if next, err := sjson.SetBytes(body, fmt.Sprintf("tools.%d.cache_control.ttl", lastIdx), claude.DefaultCacheControlTTL); err == nil {
+ body = next
+ }
+ return body
+ }
+
+ raw := fmt.Sprintf(`{"type":"ephemeral","ttl":%q}`, claude.DefaultCacheControlTTL)
+ if next, err := sjson.SetRawBytes(body, fmt.Sprintf("tools.%d.cache_control", lastIdx), []byte(raw)); err == nil {
+ body = next
+ }
+ return body
+}
+
+// restoreToolNamesInBytes 对 bytes chunk 做逆向还原:假名 → 真名。
+// 按 ReverseOrdered 的假名长度倒序逐个 bytes.Replace,防止子串冲突
+// (与 Parrot _restore_tool_names_in_chunk 的 sorted(..., reverse=True) 等价)。
+// 再做静态前缀还原(cc_sess_ → sessions_ / cc_ses_ → session_)。
+//
+// rw 可为 nil;nil 时仍会做静态前缀还原。
+func restoreToolNamesInBytes(data []byte, rw *ToolNameRewrite) []byte {
+ if rw != nil {
+ for _, pair := range rw.ReverseOrdered {
+ fake, real := pair[0], pair[1]
+ if fake == "" || fake == real {
+ continue
+ }
+ data = replaceAllBytes(data, fake, real)
+ }
+ }
+ for prefix, replacement := range staticToolNameRewrites {
+ data = replaceAllBytes(data, replacement, prefix)
+ }
+ return data
+}
+
+// replaceAllBytes 是 bytes.ReplaceAll 的便捷封装,避免每个调用点各自做 []byte 转换。
+func replaceAllBytes(data []byte, from, to string) []byte {
+ if len(data) == 0 || from == to || !strings.Contains(string(data), from) {
+ return data
+ }
+ return []byte(strings.ReplaceAll(string(data), from, to))
+}
+
+// toolNameRewriteFromContext 从 gin.Context 取出请求阶段保存的工具名映射。
+// 找不到(c==nil 或 key 不存在或类型不对)时返回 nil;调用方必须能处理 nil。
+func toolNameRewriteFromContext(c interface {
+ Get(string) (any, bool)
+}) *ToolNameRewrite {
+ if c == nil {
+ return nil
+ }
+ raw, ok := c.Get(toolNameRewriteKey)
+ if !ok || raw == nil {
+ return nil
+ }
+ rw, _ := raw.(*ToolNameRewrite)
+ return rw
+}
+
+// reverseToolNamesIfPresent 是响应侧 5 处注入点的统一封装:从 c 取出 mapping
+// 并对 chunk 做 bytes 级假名→真名替换。c 没有 mapping 时仍会做静态前缀还原。
+func reverseToolNamesIfPresent(c interface {
+ Get(string) (any, bool)
+}, chunk []byte) []byte {
+ rw := toolNameRewriteFromContext(c)
+ if rw == nil && len(staticToolNameRewrites) == 0 {
+ return chunk
+ }
+ return restoreToolNamesInBytes(chunk, rw)
+}
diff --git a/backend/internal/service/gateway_tool_rewrite_test.go b/backend/internal/service/gateway_tool_rewrite_test.go
new file mode 100644
index 00000000..8f0e3939
--- /dev/null
+++ b/backend/internal/service/gateway_tool_rewrite_test.go
@@ -0,0 +1,185 @@
+package service
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestBuildDynamicToolMap_BelowThreshold(t *testing.T) {
+ // Parrot 行为:tools 数量 ≤ 5 时不做动态映射。
+ names := []string{"bash", "edit", "read", "write", "search"}
+ require.Nil(t, buildDynamicToolMap(names))
+}
+
+func TestBuildDynamicToolMap_AboveThresholdIsStable(t *testing.T) {
+ // Parrot 不变量:同一组 tool_names 在同进程内映射稳定(保证 cache 命中)。
+ names := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}
+ a := buildDynamicToolMap(names)
+ b := buildDynamicToolMap(names)
+ require.NotNil(t, a)
+ require.Equal(t, a, b, "same input tool_names must yield identical mapping")
+ require.Len(t, a, 6)
+ for _, name := range names {
+ require.Contains(t, a, name)
+ require.NotEqual(t, name, a[name])
+ }
+}
+
+func TestSanitizeToolName_StaticPrefix(t *testing.T) {
+ require.Equal(t, "cc_sess_list", sanitizeToolName("sessions_list", nil))
+ require.Equal(t, "cc_ses_get", sanitizeToolName("session_get", nil))
+ require.Equal(t, "bash", sanitizeToolName("bash", nil))
+}
+
+func TestSanitizeToolName_DynamicTakesPrecedence(t *testing.T) {
+ dyn := map[string]string{"sessions_list": "analyze_ses00"}
+ got := sanitizeToolName("sessions_list", dyn)
+ require.Equal(t, "analyze_ses00", got, "dynamic mapping wins over static prefix")
+}
+
+func TestRestoreToolNamesInBytes_LongestFirst(t *testing.T) {
+ // 当假名 "abc_12" 是另一个更长假名的子串(真实场景极少但算法必须防御)时,
+ // 长的必须先替换。本测试用显式构造的映射来验证排序不变量。
+ rw := &ToolNameRewrite{
+ Forward: map[string]string{"foo": "abc_12", "bar": "abc_12_ext"},
+ Reverse: map[string]string{"abc_12": "foo", "abc_12_ext": "bar"},
+ }
+ // 手工构造 ReverseOrdered:长的在前
+ rw.ReverseOrdered = [][2]string{
+ {"abc_12_ext", "bar"},
+ {"abc_12", "foo"},
+ }
+ data := []byte(`{"tool":"abc_12_ext","other":"abc_12"}`)
+ restored := string(restoreToolNamesInBytes(data, rw))
+ require.Equal(t, `{"tool":"bar","other":"foo"}`, restored)
+}
+
+func TestRestoreToolNamesInBytes_StaticPrefixRollback(t *testing.T) {
+ data := []byte(`{"name":"sessions_list","id":"cc_ses_xyz"}`)
+ got := string(restoreToolNamesInBytes(data, nil))
+ require.Equal(t, `{"name":"sessions_list","id":"session_xyz"}`, got)
+}
+
+func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
+ body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"session_get","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"tool_choice":{"type":"tool","name":"sessions_list"}}`)
+ rw := buildToolNameRewriteFromBody(body)
+ require.NotNil(t, rw)
+ require.Contains(t, rw.Forward, "sessions_list")
+ require.Contains(t, rw.Forward, "session_get")
+ // web_search is a server tool, not rewritten
+ require.NotContains(t, rw.Forward, "web_search")
+
+ out := applyToolNameRewriteToBody(body, rw)
+
+ // tools[0].name and tools[1].name rewritten; tools[2].name untouched
+ require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
+ require.Equal(t, "cc_ses_get", gjson.GetBytes(out, "tools.1.name").String())
+ require.Equal(t, "web_search", gjson.GetBytes(out, "tools.2.name").String())
+
+ // tool_choice.name rewritten
+ require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tool_choice.name").String())
+ require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
+}
+
+func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
+ body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
+ out := applyToolsLastCacheBreakpoint(body)
+ require.Equal(t, "ephemeral", gjson.GetBytes(out, "tools.1.cache_control.type").String())
+ require.Equal(t, "5m", gjson.GetBytes(out, "tools.1.cache_control.ttl").String())
+ // First tool untouched
+ require.False(t, gjson.GetBytes(out, "tools.0.cache_control").Exists())
+}
+
+func TestApplyToolsLastCacheBreakpoint_PassesThroughClientTTL(t *testing.T) {
+ body := []byte(`{"tools":[{"name":"a","input_schema":{},"cache_control":{"type":"ephemeral","ttl":"1h"}}]}`)
+ out := applyToolsLastCacheBreakpoint(body)
+ // User-provided ttl must be preserved.
+ require.Equal(t, "1h", gjson.GetBytes(out, "tools.0.cache_control.ttl").String())
+}
+
+func TestStripMessageCacheControl(t *testing.T) {
+ body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral"}}]}]}`)
+ out := stripMessageCacheControl(body)
+ require.False(t, gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists())
+}
+
+func TestAddMessageCacheBreakpoints_LastMessageOnly(t *testing.T) {
+ body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
+ out := addMessageCacheBreakpoints(body)
+ require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
+ require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
+}
+
+func TestAddMessageCacheBreakpoints_SecondToLastUserTurn(t *testing.T) {
+ // Parrot 不变量:messages ≥ 4 时才打第二个断点,且位置是"倒数第二个 user turn"。
+ body := []byte(`{"messages":[
+ {"role":"user","content":[{"type":"text","text":"q1"}]},
+ {"role":"assistant","content":[{"type":"text","text":"a1"}]},
+ {"role":"user","content":[{"type":"text","text":"q2"}]},
+ {"role":"assistant","content":[{"type":"text","text":"a2"}]}
+ ]}`)
+ out := addMessageCacheBreakpoints(body)
+ // 最后一条 assistant 被打断点
+ require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.3.content.0.cache_control.type").String())
+ // 倒数第二个 user turn = index 0(唯一另一个 user)
+ require.Equal(t, "ephemeral", gjson.GetBytes(out, "messages.0.content.0.cache_control.type").String())
+ // 其他不打断点
+ require.False(t, gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists())
+ require.False(t, gjson.GetBytes(out, "messages.2.content.0.cache_control").Exists())
+}
+
+func TestAddMessageCacheBreakpoints_StringContentPromoted(t *testing.T) {
+ body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
+ out := addMessageCacheBreakpoints(body)
+ // content 升级成数组
+ require.True(t, gjson.GetBytes(out, "messages.0.content").IsArray())
+ require.Equal(t, "text", gjson.GetBytes(out, "messages.0.content.0.type").String())
+ require.Equal(t, "hi", gjson.GetBytes(out, "messages.0.content.0.text").String())
+ require.Equal(t, "5m", gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").String())
+}
+
+func TestBuildToolNameRewriteFromBody_ReverseOrderedByLengthDesc(t *testing.T) {
+ // 超过阈值触发动态映射,验证 ReverseOrdered 按假名长度倒序排列
+ body := []byte(`{"tools":[
+ {"name":"t1","input_schema":{}},
+ {"name":"t2","input_schema":{}},
+ {"name":"t3","input_schema":{}},
+ {"name":"t4","input_schema":{}},
+ {"name":"t5","input_schema":{}},
+ {"name":"t6","input_schema":{}}
+ ]}`)
+ rw := buildToolNameRewriteFromBody(body)
+ require.NotNil(t, rw)
+ require.NotEmpty(t, rw.ReverseOrdered)
+ for i := 1; i < len(rw.ReverseOrdered); i++ {
+ require.GreaterOrEqual(t, len(rw.ReverseOrdered[i-1][0]), len(rw.ReverseOrdered[i][0]),
+ "ReverseOrdered must be sorted by fake-name length descending")
+ }
+}
+
+func TestRestoreToolNamesInBytes_NoMapping_NoStaticMatch_IsNoop(t *testing.T) {
+ data := []byte("plain text without any tool names")
+ require.Equal(t, string(data), string(restoreToolNamesInBytes(data, nil)))
+}
+
+// Ensure the fake name format follows Parrot's "{prefix}{name[:3]}{i:02d}".
+func TestBuildDynamicToolMap_FakeNameShape(t *testing.T) {
+ names := []string{"alphabet", "bravo", "charlie", "delta", "echo", "foxtrot"}
+ m := buildDynamicToolMap(names)
+ require.NotNil(t, m)
+ for _, name := range names {
+ fake, ok := m[name]
+ require.True(t, ok)
+ // fake = prefix + head3 + "%02d"
+ // ends with two decimal digits
+ require.Regexp(t, `^[a-z]+_[a-z0-9]{1,3}\d{2}$`, fake)
+ head := name
+ if len(head) > 3 {
+ head = head[:3]
+ }
+ require.True(t, strings.Contains(fake, head), "fake %q should contain head3 %q of %q", fake, head, name)
+ }
+}
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 7a24071b..ea0c0d7d 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -515,6 +515,10 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
}
// Code Assist OAuth tokens often lack AI Studio scopes for models listing.
return 3
+ case AccountTypeServiceAccount:
+ // Vertex service accounts use aiplatform.googleapis.com, not the AI Studio
+ // endpoint (generativelanguage.googleapis.com), so they cannot serve these requests.
+ return 999
default:
return 10
}
@@ -579,7 +583,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -712,6 +716,36 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = "x-request-id"
+ case AccountTypeServiceAccount:
+ buildReq = func(ctx context.Context) (*http.Request, string, error) {
+ if s.tokenProvider == nil {
+ return nil, "", errors.New("gemini token provider not configured")
+ }
+ accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, "", err
+ }
+
+ action := "generateContent"
+ if req.Stream {
+ action = "streamGenerateContent"
+ }
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, req.Stream)
+ if err != nil {
+ return nil, "", err
+ }
+
+ restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ return upstreamReq, "x-request-id", nil
+ }
+ requestIDHeader = "x-request-id"
+
default:
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
}
@@ -1094,7 +1128,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
body = ensureGeminiFunctionCallThoughtSignatures(body)
mappedModel := originalModel
- if account.Type == AccountTypeAPIKey {
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -1213,6 +1247,31 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = "x-request-id"
+ case AccountTypeServiceAccount:
+ buildReq = func(ctx context.Context) (*http.Request, string, error) {
+ if s.tokenProvider == nil {
+ return nil, "", errors.New("gemini token provider not configured")
+ }
+ accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, "", err
+ }
+
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, upstreamAction, useUpstreamStream)
+ if err != nil {
+ return nil, "", err
+ }
+
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(body))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ return upstreamReq, "x-request-id", nil
+ }
+ requestIDHeader = "x-request-id"
+
default:
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Unsupported account type: "+account.Type)
}
diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go
index 7add3460..172b9411 100644
--- a/backend/internal/service/gemini_token_provider.go
+++ b/backend/internal/service/gemini_token_provider.go
@@ -15,7 +15,7 @@ const (
geminiTokenCacheSkew = 5 * time.Minute
)
-// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
+// GeminiTokenProvider manages access_token for Gemini OAuth and Vertex service account accounts.
type GeminiTokenProvider struct {
accountRepo AccountRepository
tokenCache GeminiTokenCache
@@ -53,8 +53,11 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
if account == nil {
return "", errors.New("account is nil")
}
- if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
- return "", errors.New("not a gemini oauth account")
+ if account.Platform != PlatformGemini || (account.Type != AccountTypeOAuth && account.Type != AccountTypeServiceAccount) {
+ return "", errors.New("not a gemini oauth or service account")
+ }
+ if account.Type == AccountTypeServiceAccount {
+ return p.getServiceAccountAccessToken(ctx, account)
}
cacheKey := GeminiTokenCacheKey(account)
@@ -168,7 +171,16 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
+func (p *GeminiTokenProvider) getServiceAccountAccessToken(ctx context.Context, account *Account) (string, error) {
+ return getVertexServiceAccountAccessToken(ctx, p.tokenCache, account)
+}
+
func GeminiTokenCacheKey(account *Account) string {
+ if account != nil && account.Type == AccountTypeServiceAccount {
+ if key, err := parseVertexServiceAccountKey(account); err == nil {
+ return vertexServiceAccountCacheKey(account, key)
+ }
+ }
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "gemini:" + projectID
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 12262613..bb4c5aa1 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -59,6 +59,10 @@ type Group struct {
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
+ // RPMLimit 分组级每分钟请求数上限(0 = 不限制)。
+ // 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。
+ RPMLimit int
+
CreatedAt time.Time
UpdatedAt time.Time
@@ -76,10 +80,6 @@ func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
-func (g *Group) IsFreeSubscription() bool {
- return g.IsSubscriptionType() && g.RateMultiplier == 0
-}
-
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go
index 3d706508..665922e3 100644
--- a/backend/internal/service/identity_service.go
+++ b/backend/internal/service/identity_service.go
@@ -26,7 +26,7 @@ var (
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
- UserAgent: "claude-cli/2.1.22 (external, cli)",
+ UserAgent: "claude-cli/2.1.92 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",
diff --git a/backend/internal/service/model_pricing_resolver.go b/backend/internal/service/model_pricing_resolver.go
index b7ca4cb7..58089776 100644
--- a/backend/internal/service/model_pricing_resolver.go
+++ b/backend/internal/service/model_pricing_resolver.go
@@ -61,6 +61,25 @@ type PricingInput struct {
// 1. 获取基础定价(LiteLLM → Fallback)
// 2. 如果指定了 GroupID,查找渠道定价并覆盖
func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput) *ResolvedPricing {
+ var chPricing *ChannelModelPricing
+ if input.GroupID != nil && r.channelService != nil {
+ chPricing = r.channelService.GetChannelModelPricing(ctx, *input.GroupID, input.Model)
+ if chPricing != nil {
+ mode := chPricing.BillingMode
+ if mode == "" {
+ mode = BillingModeToken
+ }
+ if mode == BillingModePerRequest || mode == BillingModeImage {
+ resolved := &ResolvedPricing{
+ Mode: mode,
+ Source: PricingSourceChannel,
+ }
+ r.applyRequestTierOverrides(chPricing, resolved)
+ return resolved
+ }
+ }
+ }
+
// 1. 获取基础定价
basePricing, source := r.resolveBasePricing(input.Model)
@@ -72,7 +91,10 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput)
}
// 2. 如果有 GroupID,尝试渠道覆盖
- if input.GroupID != nil {
+ if chPricing != nil {
+ resolved.Source = PricingSourceChannel
+ r.applyTokenOverrides(chPricing, resolved)
+ } else if input.GroupID != nil {
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
}
diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go
index 905c4df6..4548c1d5 100644
--- a/backend/internal/service/model_pricing_resolver_test.go
+++ b/backend/internal/service/model_pricing_resolver_test.go
@@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP
return map[int64]string{groupID: "anthropic"}, nil
},
}
- cs := NewChannelService(repo, nil)
+ cs := NewChannelService(repo, nil, nil, nil)
bs := newTestBillingServiceForResolver()
return NewModelPricingResolver(cs, bs)
}
@@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) {
return nil, errors.New("database unavailable")
},
}
- cs := NewChannelService(repo, nil)
+ cs := NewChannelService(repo, nil, nil, nil)
bs := newTestBillingServiceForResolver()
r := NewModelPricingResolver(cs, bs)
diff --git a/backend/internal/service/openai_403_counter.go b/backend/internal/service/openai_403_counter.go
new file mode 100644
index 00000000..5ba3e195
--- /dev/null
+++ b/backend/internal/service/openai_403_counter.go
@@ -0,0 +1,11 @@
+package service
+
+import "context"
+
+// OpenAI403CounterCache 追踪 OpenAI 账号连续 403 失败次数。
+type OpenAI403CounterCache interface {
+ // IncrementOpenAI403Count 原子递增 403 计数并返回当前值。
+ IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
+ // ResetOpenAI403Count 成功后清零计数器。
+ ResetOpenAI403Count(ctx context.Context, accountID int64) error
+}
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 6c09e354..7a0a6636 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -3,7 +3,6 @@ package service
import (
"container/heap"
"context"
- "errors"
"fmt"
"hash/fnv"
"math"
@@ -13,22 +12,40 @@ import (
"sync"
"sync/atomic"
"time"
+
+ "golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
+ openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
+const (
+ openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
+ openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
+)
+
+type cachedOpenAIAdvancedSchedulerSetting struct {
+ enabled bool
+ expiresAt int64
+}
+
+var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
+var openAIAdvancedSchedulerSettingSF singleflight.Group
+
type OpenAIAccountScheduleRequest struct {
- GroupID *int64
- SessionHash string
- StickyAccountID int64
- PreviousResponseID string
- RequestedModel string
- RequiredTransport OpenAIUpstreamTransport
- ExcludedIDs map[int64]struct{}
+ GroupID *int64
+ SessionHash string
+ StickyAccountID int64
+ PreviousResponseID string
+ RequestedModel string
+ RequiredTransport OpenAIUpstreamTransport
+ RequiredImageCapability OpenAIImagesCapability
+ RequireCompact bool
+ ExcludedIDs map[int64]struct{}
}
type OpenAIAccountScheduleDecision struct {
@@ -241,12 +258,16 @@ func (s *defaultOpenAIAccountScheduler) Select(
previousResponseID,
req.RequestedModel,
req.ExcludedIDs,
+ req.RequireCompact,
)
if err != nil {
return nil, decision, err
}
if selection != nil && selection.Account != nil {
if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
selection = nil
}
}
@@ -324,15 +345,15 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
- if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ if !s.isAccountRequestCompatible(account, req) {
return nil, nil
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
- account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel)
- if account == nil {
+ account = s.service.recheckSelectedOpenAIAccountFromDB(ctx, account, req.RequestedModel, req.RequireCompact)
+ if account == nil || !s.isAccountTransportCompatible(account, req.RequiredTransport) {
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
@@ -573,7 +594,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, err
}
if len(accounts) == 0 {
- return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
+ return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
// require_privacy_set: 获取分组信息
@@ -600,7 +621,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
- if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ if !s.isAccountRequestCompatible(account, req) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@@ -613,7 +634,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
})
}
if len(filtered) == 0 {
- return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
+ return nil, 0, 0, 0, noAvailableOpenAISelectionError(req.RequestedModel, false)
}
loadMap := map[int64]*AccountLoadInfo{}
@@ -623,45 +644,14 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
- minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
- maxWaiting := 1
- loadRateSum := 0.0
- loadRateSumSquares := 0.0
- minTTFT, maxTTFT := 0.0, 0.0
- hasTTFTSample := false
- candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
+ allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
- if account.Priority < minPriority {
- minPriority = account.Priority
- }
- if account.Priority > maxPriority {
- maxPriority = account.Priority
- }
- if loadInfo.WaitingCount > maxWaiting {
- maxWaiting = loadInfo.WaitingCount
- }
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
- if hasTTFT && ttft > 0 {
- if !hasTTFTSample {
- minTTFT, maxTTFT = ttft, ttft
- hasTTFTSample = true
- } else {
- if ttft < minTTFT {
- minTTFT = ttft
- }
- if ttft > maxTTFT {
- maxTTFT = ttft
- }
- }
- }
- loadRate := float64(loadInfo.LoadRate)
- loadRateSum += loadRate
- loadRateSumSquares += loadRate * loadRate
- candidates = append(candidates, openAIAccountCandidateScore{
+ allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
@@ -669,53 +659,183 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
hasTTFT: hasTTFT,
})
}
- loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
- weights := s.service.openAIWSSchedulerWeights()
- for i := range candidates {
- item := &candidates[i]
- priorityFactor := 1.0
- if maxPriority > minPriority {
- priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
+ // Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
+ // 时作为最后兜底(snapshot 可能已陈旧)。
+ candidates := allCandidates
+ staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
+ if req.RequireCompact {
+ candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
+ for _, candidate := range allCandidates {
+ if openAICompactSupportTier(candidate.account) == 0 {
+ staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
+ continue
+ }
+ candidates = append(candidates, candidate)
}
- loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
- queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
- errorFactor := 1 - clamp01(item.errorRate)
- ttftFactor := 0.5
- if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
- ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
+ if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
+ return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
-
- item.score = weights.Priority*priorityFactor +
- weights.Load*loadFactor +
- weights.Queue*queueFactor +
- weights.ErrorRate*errorFactor +
- weights.TTFT*ttftFactor
}
- topK := s.service.openAIWSLBTopK()
- if topK > len(candidates) {
- topK = len(candidates)
- }
- if topK <= 0 {
- topK = 1
- }
- rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
- selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
+ candidateCount := len(candidates)
+ loadSkew := 0.0
+ if len(candidates) > 0 {
+ minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
+ maxWaiting := 1
+ loadRateSum := 0.0
+ loadRateSumSquares := 0.0
+ minTTFT, maxTTFT := 0.0, 0.0
+ hasTTFTSample := false
+ for _, candidate := range candidates {
+ if candidate.account.Priority < minPriority {
+ minPriority = candidate.account.Priority
+ }
+ if candidate.account.Priority > maxPriority {
+ maxPriority = candidate.account.Priority
+ }
+ if candidate.loadInfo.WaitingCount > maxWaiting {
+ maxWaiting = candidate.loadInfo.WaitingCount
+ }
+ if candidate.hasTTFT && candidate.ttft > 0 {
+ if !hasTTFTSample {
+ minTTFT, maxTTFT = candidate.ttft, candidate.ttft
+ hasTTFTSample = true
+ } else {
+ if candidate.ttft < minTTFT {
+ minTTFT = candidate.ttft
+ }
+ if candidate.ttft > maxTTFT {
+ maxTTFT = candidate.ttft
+ }
+ }
+ }
+ loadRate := float64(candidate.loadInfo.LoadRate)
+ loadRateSum += loadRate
+ loadRateSumSquares += loadRate * loadRate
+ }
+ loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
+ weights := s.service.openAIWSSchedulerWeights()
+ for i := range candidates {
+ item := &candidates[i]
+ priorityFactor := 1.0
+ if maxPriority > minPriority {
+ priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
+ }
+ loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
+ queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
+ errorFactor := 1 - clamp01(item.errorRate)
+ ttftFactor := 0.5
+ if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
+ ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
+ }
+
+ item.score = weights.Priority*priorityFactor +
+ weights.Load*loadFactor +
+ weights.Queue*queueFactor +
+ weights.ErrorRate*errorFactor +
+ weights.TTFT*ttftFactor
+ }
+ }
+
+ topK := 0
+ if len(candidates) > 0 {
+ topK = s.service.openAIWSLBTopK()
+ if topK > len(candidates) {
+ topK = len(candidates)
+ }
+ if topK <= 0 {
+ topK = 1
+ }
+ }
+
+ buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
+ if len(pool) == 0 || topK <= 0 {
+ return nil
+ }
+ groupTopK := topK
+ if groupTopK > len(pool) {
+ groupTopK = len(pool)
+ }
+ ranked := selectTopKOpenAICandidates(pool, groupTopK)
+ return buildOpenAIWeightedSelectionOrder(ranked, req)
+ }
+ sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
+ if len(pool) == 0 {
+ return nil
+ }
+ ordered := append([]openAIAccountCandidateScore(nil), pool...)
+ sort.SliceStable(ordered, func(i, j int) bool {
+ a, b := ordered[i], ordered[j]
+ if a.account.Priority != b.account.Priority {
+ return a.account.Priority < b.account.Priority
+ }
+ if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
+ return a.loadInfo.LoadRate < b.loadInfo.LoadRate
+ }
+ if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
+ return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
+ }
+ switch {
+ case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
+ return true
+ case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
+ return false
+ case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
+ return false
+ default:
+ return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
+ }
+ })
+ return ordered
+ }
+
+ selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
+ if req.RequireCompact {
+ supported := make([]openAIAccountCandidateScore, 0, len(candidates))
+ unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
+ for _, candidate := range candidates {
+ switch openAICompactSupportTier(candidate.account) {
+ case 2:
+ supported = append(supported, candidate)
+ case 1:
+ unknown = append(unknown, candidate)
+ }
+ }
+ if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
+ return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
+ }
+ selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
+ selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
+ if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
+ selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
+ }
+ } else {
+ selectionOrder = buildSelectionOrder(candidates)
+ }
+ if len(selectionOrder) == 0 {
+ return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
+ }
+
+ compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
- fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
+ fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
continue
}
- fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
+ fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ continue
+ }
+ if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
+ compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
- return nil, len(candidates), topK, loadSkew, acquireErr
+ return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
@@ -725,15 +845,23 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
- }, len(candidates), topK, loadSkew, nil
+ }, candidateCount, topK, loadSkew, nil
}
}
cfg := s.service.schedulingConfig()
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
- fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) {
+ fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ continue
+ }
+ fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ continue
+ }
+ if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
+ compactBlocked = true
continue
}
return &AccountSelectionResult{
@@ -744,21 +872,30 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
- }, len(candidates), topK, loadSkew, nil
+ }, candidateCount, topK, loadSkew, nil
}
- return nil, len(candidates), topK, loadSkew, ErrNoAvailableAccounts
+ return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, compactBlocked)
}
func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
- // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
return true
}
- if s == nil || s.service == nil || account == nil {
+ if s == nil || s.service == nil {
return false
}
- return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+ return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
+}
+
+func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
+ if account == nil {
+ return false
+ }
+ if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ return false
+ }
+ return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
}
func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
@@ -805,10 +942,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
-func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
+func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
+ if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
+ return nil
+ }
+ return s.rateLimitService.settingService.settingRepo
+}
+
+func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled
+ }
+ }
+
+ result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
+ if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.enabled, nil
+ }
+ }
+
+ enabled := false
+ if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
+ dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
+ defer cancel()
+
+ value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
+ if err == nil {
+ enabled = strings.EqualFold(strings.TrimSpace(value), "true")
+ }
+ }
+
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: enabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ return enabled, nil
+ })
+
+ enabled, _ := result.(bool)
+ return enabled
+}
+
+func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
+ if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
+ return nil
+ }
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
@@ -820,6 +1003,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
+func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
+ openAIAdvancedSchedulerSettingCache = atomic.Value{}
+ openAIAdvancedSchedulerSettingSF = singleflight.Group{}
+}
+
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
@@ -828,13 +1016,94 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requestedModel string,
excludedIDs map[int64]struct{},
requiredTransport OpenAIUpstreamTransport,
+ requireCompact bool,
+) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
+ return s.selectAccountWithScheduler(ctx, groupID, previousResponseID, sessionHash, requestedModel, excludedIDs, requiredTransport, "", requireCompact)
+}
+
+func (s *OpenAIGatewayService) SelectAccountWithSchedulerForImages(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredCapability OpenAIImagesCapability,
+) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
+ selection, decision, err := s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, requiredCapability, false)
+ if err == nil && selection != nil && selection.Account != nil {
+ return selection, decision, nil
+ }
+ // 如果要求 native 能力(如指定了模型)但没有可用的 APIKey 账号,回退到 basic(OAuth 账号)
+ if requiredCapability == OpenAIImagesCapabilityNative {
+ return s.selectAccountWithScheduler(ctx, groupID, "", sessionHash, requestedModel, excludedIDs, OpenAIUpstreamTransportHTTPSSE, OpenAIImagesCapabilityBasic, false)
+ }
+ return selection, decision, err
+}
+
+func (s *OpenAIGatewayService) selectAccountWithScheduler(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport OpenAIUpstreamTransport,
+ requiredImageCapability OpenAIImagesCapability,
+ requireCompact bool,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
- selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
- return selection, decision, err
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
+ for {
+ selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection == nil || selection.Account == nil {
+ return selection, decision, nil
+ }
+ if selection.Account.SupportsOpenAIImageCapability(requiredImageCapability) {
+ return selection, decision, nil
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if effectiveExcludedIDs == nil {
+ effectiveExcludedIDs = make(map[int64]struct{})
+ }
+ if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
+ return nil, decision, ErrNoAvailableAccounts
+ }
+ effectiveExcludedIDs[selection.Account.ID] = struct{}{}
+ }
+ }
+
+ effectiveExcludedIDs := cloneExcludedAccountIDs(excludedIDs)
+ for {
+ selection, err := s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, effectiveExcludedIDs, requireCompact)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection == nil || selection.Account == nil {
+ return selection, decision, nil
+ }
+ if s.isOpenAIAccountTransportCompatible(selection.Account, requiredTransport) {
+ return selection, decision, nil
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if effectiveExcludedIDs == nil {
+ effectiveExcludedIDs = make(map[int64]struct{})
+ }
+ if _, exists := effectiveExcludedIDs[selection.Account.ID]; exists {
+ return nil, decision, ErrNoAvailableAccounts
+ }
+ effectiveExcludedIDs[selection.Account.ID] = struct{}{}
+ }
}
var stickyAccountID int64
@@ -845,18 +1114,41 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
- GroupID: groupID,
- SessionHash: sessionHash,
- StickyAccountID: stickyAccountID,
- PreviousResponseID: previousResponseID,
- RequestedModel: requestedModel,
- RequiredTransport: requiredTransport,
- ExcludedIDs: excludedIDs,
+ GroupID: groupID,
+ SessionHash: sessionHash,
+ StickyAccountID: stickyAccountID,
+ PreviousResponseID: previousResponseID,
+ RequestedModel: requestedModel,
+ RequiredTransport: requiredTransport,
+ RequiredImageCapability: requiredImageCapability,
+ RequireCompact: requireCompact,
+ ExcludedIDs: excludedIDs,
})
}
+func cloneExcludedAccountIDs(excludedIDs map[int64]struct{}) map[int64]struct{} {
+ if len(excludedIDs) == 0 {
+ return nil
+ }
+ cloned := make(map[int64]struct{}, len(excludedIDs))
+ for id := range excludedIDs {
+ cloned[id] = struct{}{}
+ }
+ return cloned
+}
+
+func (s *OpenAIGatewayService) isOpenAIAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
+ if s == nil || account == nil {
+ return false
+ }
+ return s.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+}
+
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -864,7 +1156,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -872,7 +1164,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
- scheduler := s.getOpenAIAccountScheduler()
+ scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}
diff --git a/backend/internal/service/openai_account_scheduler_compact_test.go b/backend/internal/service/openai_account_scheduler_compact_test.go
new file mode 100644
index 00000000..f7e08a20
--- /dev/null
+++ b/backend/internal/service/openai_account_scheduler_compact_test.go
@@ -0,0 +1,195 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown
+// 验证 compact 调度时显式支持 (tier=2) 优先于未探测 (tier=1)。
+func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactPrefersSupportedOverUnknown(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(91001)
+ accounts := []Account{
+ {
+ ID: 71001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{}, // unknown
+ },
+ {
+ ID: 71002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{"openai_compact_supported": true}, // tier=2
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.4",
+ nil,
+ OpenAIUpstreamTransportAny,
+ true,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(71002), selection.Account.ID, "compact-supported account should win over unknown")
+}
+
+// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported
+// 验证 force_off / 已探测不支持 (tier=0) 的账号不会被 compact 请求选中。
+func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactRejectsExplicitlyUnsupported(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(91002)
+ accounts := []Account{
+ {
+ ID: 71010,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
+ },
+ {
+ ID: 71011,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{"openai_compact_supported": false},
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.4",
+ nil,
+ OpenAIUpstreamTransportAny,
+ true,
+ )
+ require.Error(t, err)
+ require.True(t, errors.Is(err, ErrNoAvailableCompactAccounts), "compact-only accounts should rejected explicitly unsupported and return compact error")
+ require.Nil(t, selection)
+}
+
+// TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown
+// 验证当没有"已知支持"账号时,compact 请求会回退到"未探测"账号。
+func TestOpenAIGatewayService_SelectAccountWithScheduler_CompactFallsBackToUnknown(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(91003)
+ accounts := []Account{
+ {
+ ID: 71020,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{"openai_compact_supported": false}, // tier=0
+ },
+ {
+ ID: 71021,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{}, // unknown -> tier=1
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.4",
+ nil,
+ OpenAIUpstreamTransportAny,
+ true,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(71021), selection.Account.ID, "unknown account should be picked when no supported account available")
+}
+
+// TestOpenAICompactSupportTier 验证 tier 分类逻辑。
+func TestOpenAICompactSupportTier(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want int
+ }{
+ {name: "nil", account: nil, want: 0},
+ {name: "non openai", account: &Account{Platform: PlatformAnthropic}, want: 0},
+ {name: "openai unknown", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{}}, want: 1},
+ {name: "openai supported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": true}}, want: 2},
+ {name: "openai unsupported", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_supported": false}}, want: 0},
+ {name: "force on", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn}}, want: 2},
+ {name: "force off overrides probe true", account: &Account{Platform: PlatformOpenAI, Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff, "openai_compact_supported": true}}, want: 0},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := openAICompactSupportTier(tt.account); got != tt.want {
+ t.Fatalf("openAICompactSupportTier(...) = %d, want %d", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index 088815ed..0950ee54 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "errors"
"fmt"
"math"
"sync"
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
+type schedulerTestOpenAIAccountRepo struct {
+ AccountRepository
+ accounts []Account
+}
+
+func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, errors.New("account not found")
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ var result []Account
+ for _, acc := range r.accounts {
+ if acc.Platform == platform {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
+func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
+ return r.ListSchedulableByPlatform(ctx, platform)
+}
+
+type schedulerTestConcurrencyCache struct {
+ ConcurrencyCache
+ loadBatchErr error
+ loadMap map[int64]*AccountLoadInfo
+ acquireResults map[int64]bool
+ waitCounts map[int64]int
+ skipDefaultLoad bool
+}
+
+func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ if c.acquireResults != nil {
+ if result, ok := c.acquireResults[accountID]; ok {
+ return result, nil
+ }
+ }
+ return true, nil
+}
+
+func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
+ return nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
+ if c.loadBatchErr != nil {
+ return nil, c.loadBatchErr
+ }
+ out := make(map[int64]*AccountLoadInfo, len(accounts))
+ if c.skipDefaultLoad && c.loadMap != nil {
+ for _, acc := range accounts {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ }
+ }
+ return out, nil
+ }
+ for _, acc := range accounts {
+ if c.loadMap != nil {
+ if load, ok := c.loadMap[acc.ID]; ok {
+ out[acc.ID] = load
+ continue
+ }
+ }
+ out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
+ }
+ return out, nil
+}
+
+func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
+ if c.waitCounts != nil {
+ if count, ok := c.waitCounts[accountID]; ok {
+ return count, nil
+ }
+ }
+ return 0, nil
+}
+
+type schedulerTestGatewayCache struct {
+ sessionBindings map[string]int64
+ deletedSessions map[string]int
+}
+
+func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
+ if id, ok := c.sessionBindings[sessionHash]; ok {
+ return id, nil
+ }
+ return 0, errors.New("not found")
+}
+
+func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
+ if c.sessionBindings == nil {
+ c.sessionBindings = make(map[string]int64)
+ }
+ c.sessionBindings[sessionHash] = accountID
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
+ return nil
+}
+
+func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
+ if c.sessionBindings == nil {
+ return nil
+ }
+ if c.deletedSessions == nil {
+ c.deletedSessions = make(map[string]int)
+ }
+ c.deletedSessions[sessionHash]++
+ delete(c.sessionBindings, sessionHash)
+ return nil
+}
+
+func newSchedulerTestOpenAIWSV2Config() *config.Config {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ return cfg
+}
+
+type openAIAdvancedSchedulerSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &Setting{Key: key, Value: value}, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if s == nil || s.values == nil {
+ return "", ErrSettingNotFound
+ }
+ value, ok := s.values[key]
+ if !ok {
+ return "", ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected call to Set")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
+ panic("unexpected call to GetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected call to SetMultiple")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected call to GetAll")
+}
+
+func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected call to Delete")
+}
+
+func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+ repo := &openAIAdvancedSchedulerSettingRepoStub{
+ values: map[string]string{},
+ }
+ if enabled != "" {
+ repo.values[openAIAdvancedSchedulerSettingKey] = enabled
+ }
+ return &RateLimitService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+}
+
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
@@ -45,6 +242,234 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10106)
+ accounts := []Account{
+ {
+ ID: 36001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ },
+ {
+ ID: 36002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cache := &schedulerTestGatewayCache{}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
+ require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_disabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ false,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36002), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.False(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_SkipsHTTPOnlyAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10108)
+ accounts := []Account{
+ {
+ ID: 36011,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 36012,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(36012), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabled_RequiredWSV2_NoAvailableAccount(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10109)
+ accounts := []Account{
+ {
+ ID: 36021,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := newSchedulerTestOpenAIWSV2Config()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
+ )
+ require.ErrorContains(t, err, "no available OpenAI accounts")
+ require.Nil(t, selection)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ ctx := context.Background()
+ groupID := int64(10107)
+ accounts := []Account{
+ {
+ ID: 37001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ {
+ ID: 37002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.LoadBatchEnabled = false
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
+ require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_enabled_001",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ false,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(37001), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
+ require.True(t, decision.StickyPreviousHit)
+}
+
+func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
+ svc := &OpenAIGatewayService{}
+ ttft := 120
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.RecordOpenAIAccountSwitch()
+
+ snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
@@ -53,12 +478,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
+ cache: cache,
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
+ }
- selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
+ selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -76,7 +508,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
- svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
+ svc := &OpenAIGatewayService{
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
+ cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ schedulerSnapshot: snapshotService,
+ }
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
@@ -92,21 +529,22 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
- cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
+ cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
- selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
+ selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -128,8 +566,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
@@ -153,7 +592,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
- cache := &stubGatewayCache{}
+ cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
@@ -163,10 +602,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
@@ -180,6 +620,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -204,17 +645,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -225,6 +667,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -260,7 +703,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
@@ -273,7 +716,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
@@ -288,9 +731,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -302,6 +746,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -328,17 +773,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -349,6 +795,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -387,15 +834,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
@@ -403,9 +850,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -417,6 +865,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -445,10 +894,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
- cfg: newOpenAIWSV2TestConfig(),
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
+ cfg: newSchedulerTestOpenAIWSV2Config(),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -459,6 +909,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
)
require.Error(t, err)
require.Nil(t, selection)
@@ -507,7 +958,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
@@ -520,9 +971,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -534,6 +986,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -559,19 +1012,20 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
- cache := &stubGatewayCache{
+ cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
- selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
+ selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny, false)
require.NoError(t, err)
require.NotNil(t, selection)
svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
@@ -749,7 +1203,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
- concurrencyCache := stubConcurrencyCache{
+ concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
@@ -757,9 +1211,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: accounts},
- cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
+ cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -774,6 +1229,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
@@ -905,12 +1361,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
+ resetOpenAIAdvancedSchedulerSettingCacheForTest()
+
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
- require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
+ require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
@@ -947,7 +1405,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
- cfg := newOpenAIWSV2TestConfig()
+ cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,
diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
index c5de8203..8d63e68e 100644
--- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
+++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go
@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
- accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
- cache: &stubGatewayCache{},
+ accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
+ cache: &schedulerTestGatewayCache{},
cfg: cfg,
+ rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
- concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -53,6 +54,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
"gpt-5.1",
nil,
OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
)
require.NoError(t, err)
require.NotNil(t, selection)
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index a266d6a0..b256f1c7 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -1,14 +1,15 @@
package service
import (
+ "encoding/json"
"fmt"
"strings"
)
var codexModelMap = map[string]string{
+ "gpt-5.5": "gpt-5.5",
"gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
- "gpt-5.4-nano": "gpt-5.4-nano",
"gpt-5.4-none": "gpt-5.4",
"gpt-5.4-low": "gpt-5.4",
"gpt-5.4-medium": "gpt-5.4",
@@ -22,52 +23,21 @@ var codexModelMap = map[string]string{
"gpt-5.3-high": "gpt-5.3-codex",
"gpt-5.3-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-low": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-medium": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3-codex-low": "gpt-5.3-codex",
"gpt-5.3-codex-medium": "gpt-5.3-codex",
"gpt-5.3-codex-high": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
- "gpt-5.1-codex": "gpt-5.1-codex",
- "gpt-5.1-codex-low": "gpt-5.1-codex",
- "gpt-5.1-codex-medium": "gpt-5.1-codex",
- "gpt-5.1-codex-high": "gpt-5.1-codex",
- "gpt-5.1-codex-max": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
- "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2",
- "gpt-5.2-codex": "gpt-5.2-codex",
- "gpt-5.2-codex-low": "gpt-5.2-codex",
- "gpt-5.2-codex-medium": "gpt-5.2-codex",
- "gpt-5.2-codex-high": "gpt-5.2-codex",
- "gpt-5.2-codex-xhigh": "gpt-5.2-codex",
- "gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
- "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
- "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
- "gpt-5.1": "gpt-5.1",
- "gpt-5.1-none": "gpt-5.1",
- "gpt-5.1-low": "gpt-5.1",
- "gpt-5.1-medium": "gpt-5.1",
- "gpt-5.1-high": "gpt-5.1",
- "gpt-5.1-chat-latest": "gpt-5.1",
- "gpt-5-codex": "gpt-5.1-codex",
- "codex-mini-latest": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
- "gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
- "gpt-5": "gpt-5.1",
- "gpt-5-mini": "gpt-5.1",
- "gpt-5-nano": "gpt-5.1",
}
type codexTransformResult struct {
@@ -76,6 +46,30 @@ type codexTransformResult struct {
PromptCacheKey string
}
+const (
+ codexImageGenerationBridgeMarker = ""
+ codexImageGenerationBridgeText = codexImageGenerationBridgeMarker + "\nWhen the user asks for raster image generation or editing, use the OpenAI Responses native `image_generation` tool attached to this request. The local Codex client may not expose an `image_gen` namespace, but that does not mean image generation is unavailable. Do not ask the user to switch to CLI fallback solely because `image_gen` is absent.\n "
+ codexSparkImageUnsupportedMarker = ""
+ codexSparkImageUnsupportedText = codexSparkImageUnsupportedMarker + "\nThe current model is gpt-5.3-codex-spark, which does not support image generation, image editing, image input, the `image_generation` tool, or Codex `image_gen`/`$imagegen` workflows. If the user asks for image generation or image editing, clearly explain this model limitation and ask them to switch to a non-Spark Codex model such as gpt-5.3-codex or gpt-5.4. Do not claim that the local environment merely lacks image_gen tooling, and do not suggest CLI fallback as the primary fix while the model remains Spark.\n "
+)
+
+var openAIChatGPTInternalUnsupportedFields = []string{
+ "user",
+ "metadata",
+ "prompt_cache_retention",
+ "safety_identifier",
+ "stream_options",
+}
+
+var openAICodexOAuthUnsupportedFields = append([]string{
+ "max_output_tokens",
+ "max_completion_tokens",
+ "temperature",
+ "top_p",
+ "frequency_penalty",
+ "presence_penalty",
+}, openAIChatGPTInternalUnsupportedFields...)
+
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
@@ -116,23 +110,8 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
}
}
- // Strip parameters unsupported by codex models via the Responses API.
- for _, key := range []string{
- "max_output_tokens",
- "max_completion_tokens",
- "temperature",
- "top_p",
- "frequency_penalty",
- "presence_penalty",
- // prompt_cache_retention is a newer Responses API parameter (cache TTL).
- // The ChatGPT internal Codex endpoint rejects it with
- // "Unsupported parameter: prompt_cache_retention". Defense-in-depth
- // for any OAuth path that reaches this transform — the Cursor
- // Responses-shape short-circuit in ForwardAsChatCompletions strips
- // it earlier too, but we keep this line so other OAuth callers are
- // equally protected.
- "prompt_cache_retention",
- } {
+ // Strip parameters unsupported by ChatGPT internal Codex endpoint.
+ for _, key := range openAICodexOAuthUnsupportedFields {
if _, ok := reqBody[key]; ok {
delete(reqBody, key)
result.Modified = true
@@ -164,9 +143,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if name, ok := fcObj["name"].(string); ok && strings.TrimSpace(name) != "" {
reqBody["tool_choice"] = map[string]any{
"type": "function",
- "function": map[string]any{
- "name": name,
- },
+ "name": name,
}
}
}
@@ -177,6 +154,9 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if normalizeCodexTools(reqBody) {
result.Modified = true
}
+ if normalizeCodexToolChoice(reqBody) {
+ result.Modified = true
+ }
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
@@ -191,9 +171,20 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
if applyInstructions(reqBody, isCodexCLI) {
result.Modified = true
}
+ if isCodexSparkModel(normalizedModel) && applyCodexSparkImageUnsupportedInstructions(reqBody) {
+ result.Modified = true
+ }
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
if input, ok := reqBody["input"].([]any); ok {
+ if normalizedInput, modified := normalizeCodexToolRoleMessages(input); modified {
+ input = normalizedInput
+ result.Modified = true
+ }
+ if normalizedInput, modified := normalizeCodexMessageContentText(input); modified {
+ input = normalizedInput
+ result.Modified = true
+ }
input = filterCodexInput(input, needsToolContinuation)
reqBody["input"] = input
result.Modified = true
@@ -218,9 +209,246 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
return result
}
+func normalizeCodexToolChoice(reqBody map[string]any) bool {
+ choice, ok := reqBody["tool_choice"]
+ if !ok || choice == nil {
+ return false
+ }
+ choiceMap, ok := choice.(map[string]any)
+ if !ok {
+ return false
+ }
+ choiceType := strings.TrimSpace(firstNonEmptyString(choiceMap["type"]))
+ if choiceType == "" {
+ return false
+ }
+ modified := false
+ if choiceType == "function" {
+ name := strings.TrimSpace(firstNonEmptyString(choiceMap["name"]))
+ if name == "" {
+ if function, ok := choiceMap["function"].(map[string]any); ok {
+ name = strings.TrimSpace(firstNonEmptyString(function["name"]))
+ }
+ }
+ if name == "" {
+ reqBody["tool_choice"] = "auto"
+ return true
+ }
+ if strings.TrimSpace(firstNonEmptyString(choiceMap["name"])) != name {
+ choiceMap["name"] = name
+ modified = true
+ }
+ if _, ok := choiceMap["function"]; ok {
+ delete(choiceMap, "function")
+ modified = true
+ }
+ if !codexToolsContainFunctionName(reqBody["tools"], name) {
+ reqBody["tool_choice"] = "auto"
+ return true
+ }
+ return modified
+ }
+ if codexToolsContainType(reqBody["tools"], choiceType) {
+ return modified
+ }
+ reqBody["tool_choice"] = "auto"
+ return true
+}
+
+func codexToolsContainType(rawTools any, toolType string) bool {
+ tools, ok := rawTools.([]any)
+ if !ok || strings.TrimSpace(toolType) == "" {
+ return false
+ }
+ for _, rawTool := range tools {
+ tool, ok := rawTool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if strings.TrimSpace(firstNonEmptyString(tool["type"])) == toolType {
+ return true
+ }
+ }
+ return false
+}
+
+func codexToolsContainFunctionName(rawTools any, name string) bool {
+ tools, ok := rawTools.([]any)
+ if !ok || strings.TrimSpace(name) == "" {
+ return false
+ }
+ normalizedName := strings.TrimSpace(name)
+ for _, rawTool := range tools {
+ tool, ok := rawTool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if strings.TrimSpace(firstNonEmptyString(tool["type"])) != "function" {
+ continue
+ }
+ toolName := strings.TrimSpace(firstNonEmptyString(tool["name"]))
+ if toolName == "" {
+ if function, ok := tool["function"].(map[string]any); ok {
+ toolName = strings.TrimSpace(firstNonEmptyString(function["name"]))
+ }
+ }
+ if toolName == normalizedName {
+ return true
+ }
+ }
+ return false
+}
+
+func normalizeCodexToolRoleMessages(input []any) ([]any, bool) {
+ if len(input) == 0 {
+ return input, false
+ }
+
+ modified := false
+ normalized := make([]any, 0, len(input))
+ for _, item := range input {
+ m, ok := item.(map[string]any)
+ if !ok {
+ normalized = append(normalized, item)
+ continue
+ }
+ role, _ := m["role"].(string)
+ if strings.TrimSpace(role) != "tool" {
+ normalized = append(normalized, item)
+ continue
+ }
+
+ callID := firstNonEmptyString(m["call_id"], m["tool_call_id"], m["id"])
+ callID = strings.TrimSpace(callID)
+ if callID == "" {
+ // Responses does not accept role:"tool". If no call id is available,
+ // preserve the text as a user message instead of sending invalid input.
+ fallback := make(map[string]any, len(m))
+ for key, value := range m {
+ fallback[key] = value
+ }
+ fallback["role"] = "user"
+ delete(fallback, "tool_call_id")
+ normalized = append(normalized, fallback)
+ modified = true
+ continue
+ }
+
+ output := extractTextFromContent(m["content"])
+ if output == "" {
+ if value, ok := m["output"].(string); ok {
+ output = value
+ }
+ }
+ if output == "" && m["content"] != nil {
+ if b, err := json.Marshal(m["content"]); err == nil {
+ output = string(b)
+ }
+ }
+
+ normalized = append(normalized, map[string]any{
+ "type": "function_call_output",
+ "call_id": callID,
+ "output": output,
+ })
+ modified = true
+ }
+ if !modified {
+ return input, false
+ }
+ return normalized, true
+}
+
+func normalizeCodexMessageContentText(input []any) ([]any, bool) {
+ if len(input) == 0 {
+ return input, false
+ }
+
+ modified := false
+ normalized := make([]any, 0, len(input))
+ for _, item := range input {
+ m, ok := item.(map[string]any)
+ if !ok || strings.TrimSpace(firstNonEmptyString(m["type"])) != "message" {
+ normalized = append(normalized, item)
+ continue
+ }
+ parts, ok := m["content"].([]any)
+ if !ok {
+ normalized = append(normalized, item)
+ continue
+ }
+
+ var newItem map[string]any
+ var newParts []any
+ ensureItemCopy := func() {
+ if newItem != nil {
+ return
+ }
+ newItem = make(map[string]any, len(m))
+ for key, value := range m {
+ newItem[key] = value
+ }
+ newParts = make([]any, len(parts))
+ copy(newParts, parts)
+ }
+
+ for i, rawPart := range parts {
+ part, ok := rawPart.(map[string]any)
+ if !ok {
+ continue
+ }
+ text, hasText := part["text"]
+ if !hasText {
+ continue
+ }
+ if _, ok := text.(string); ok {
+ continue
+ }
+
+ ensureItemCopy()
+ newPart := make(map[string]any, len(part))
+ for key, value := range part {
+ newPart[key] = value
+ }
+ newPart["text"] = stringifyCodexContentText(text)
+ newParts[i] = newPart
+ modified = true
+ }
+
+ if newItem != nil {
+ newItem["content"] = newParts
+ normalized = append(normalized, newItem)
+ continue
+ }
+ normalized = append(normalized, item)
+ }
+ if !modified {
+ return input, false
+ }
+ return normalized, true
+}
+
+func stringifyCodexContentText(value any) string {
+ switch v := value.(type) {
+ case string:
+ return v
+ case nil:
+ return ""
+ default:
+ if b, err := json.Marshal(v); err == nil {
+ return string(b)
+ }
+ return fmt.Sprint(v)
+ }
+}
+
func normalizeCodexModel(model string) string {
+ model = strings.TrimSpace(model)
if model == "" {
- return "gpt-5.1"
+ return "gpt-5.4"
+ }
+ if isOpenAIImageGenerationModel(model) {
+ return model
}
modelID := model
@@ -235,52 +463,299 @@ func normalizeCodexModel(model string) string {
normalized := strings.ToLower(modelID)
+ if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") {
+ return "gpt-5.5"
+ }
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
return "gpt-5.4-mini"
}
- if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") {
- return "gpt-5.4-nano"
- }
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
return "gpt-5.4"
}
- if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
- return "gpt-5.2-codex"
- }
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
+ if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
+ return "gpt-5.3-codex-spark"
+ }
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3-codex"
}
- if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
- return "gpt-5.1-codex-max"
- }
- if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
- return "gpt-5.1-codex-mini"
- }
- if strings.Contains(normalized, "codex-mini-latest") ||
- strings.Contains(normalized, "gpt-5-codex-mini") ||
- strings.Contains(normalized, "gpt 5 codex mini") {
- return "codex-mini-latest"
- }
- if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
- return "gpt-5.1-codex"
- }
- if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
- return "gpt-5.1"
- }
if strings.Contains(normalized, "codex") {
- return "gpt-5.1-codex"
+ return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
- return "gpt-5.1"
+ return "gpt-5.4"
}
- return "gpt-5.1"
+ return "gpt-5.4"
+}
+
+func isCodexSparkModel(model string) bool {
+ return normalizeCodexModel(model) == "gpt-5.3-codex-spark"
+}
+
+func hasOpenAIImageGenerationTool(reqBody map[string]any) bool {
+ rawTools, ok := reqBody["tools"]
+ if !ok || rawTools == nil {
+ return false
+ }
+ tools, ok := rawTools.([]any)
+ if !ok {
+ return false
+ }
+ for _, rawTool := range tools {
+ toolMap, ok := rawTool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" {
+ return true
+ }
+ }
+ return false
+}
+
+func hasOpenAIInputImage(reqBody map[string]any) bool {
+ if reqBody == nil {
+ return false
+ }
+ return hasOpenAIInputImageValue(reqBody["input"]) || hasOpenAIInputImageValue(reqBody["messages"])
+}
+
+func hasOpenAIInputImageValue(value any) bool {
+ switch v := value.(type) {
+ case []any:
+ for _, item := range v {
+ if hasOpenAIInputImageValue(item) {
+ return true
+ }
+ }
+ case map[string]any:
+ if strings.TrimSpace(firstNonEmptyString(v["type"])) == "input_image" {
+ return true
+ }
+ if _, ok := v["image_url"]; ok {
+ return true
+ }
+ return hasOpenAIInputImageValue(v["content"])
+ }
+ return false
+}
+
+func validateCodexSparkInput(reqBody map[string]any, model string) error {
+ if !isCodexSparkModel(model) || !hasOpenAIInputImage(reqBody) {
+ return nil
+ }
+ return fmt.Errorf("model %q does not support image input", strings.TrimSpace(model))
+}
+
+func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool {
+ rawTools, ok := reqBody["tools"]
+ if !ok || rawTools == nil {
+ return false
+ }
+ tools, ok := rawTools.([]any)
+ if !ok {
+ return false
+ }
+
+ modified := false
+ for _, rawTool := range tools {
+ toolMap, ok := rawTool.(map[string]any)
+ if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
+ continue
+ }
+ if _, ok := toolMap["output_format"]; !ok {
+ if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != "" {
+ toolMap["output_format"] = value
+ modified = true
+ }
+ }
+ if _, ok := toolMap["output_compression"]; !ok {
+ if value, exists := toolMap["compression"]; exists && value != nil {
+ toolMap["output_compression"] = value
+ modified = true
+ }
+ }
+ if _, ok := toolMap["format"]; ok {
+ delete(toolMap, "format")
+ modified = true
+ }
+ if _, ok := toolMap["compression"]; ok {
+ delete(toolMap, "compression")
+ modified = true
+ }
+ }
+ return modified
+}
+
+func ensureOpenAIResponsesImageGenerationTool(reqBody map[string]any) bool {
+ if len(reqBody) == 0 {
+ return false
+ }
+ if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
+ return false
+ }
+
+ tool := map[string]any{
+ "type": "image_generation",
+ "output_format": "png",
+ }
+
+ rawTools, ok := reqBody["tools"]
+ if !ok || rawTools == nil {
+ reqBody["tools"] = []any{tool}
+ return true
+ }
+
+ tools, ok := rawTools.([]any)
+ if !ok {
+ reqBody["tools"] = []any{tool}
+ return true
+ }
+ for _, rawTool := range tools {
+ toolMap, ok := rawTool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" {
+ return false
+ }
+ }
+
+ reqBody["tools"] = append(tools, tool)
+ return true
+}
+
+func applyCodexImageGenerationBridgeInstructions(reqBody map[string]any) bool {
+ if len(reqBody) == 0 || !hasOpenAIImageGenerationTool(reqBody) {
+ return false
+ }
+ if isCodexSparkModel(firstNonEmptyString(reqBody["model"])) {
+ return false
+ }
+
+ existing, _ := reqBody["instructions"].(string)
+ if strings.Contains(existing, codexImageGenerationBridgeMarker) {
+ return false
+ }
+
+ existing = strings.TrimRight(existing, " \t\r\n")
+ if strings.TrimSpace(existing) == "" {
+ reqBody["instructions"] = codexImageGenerationBridgeText
+ return true
+ }
+
+ reqBody["instructions"] = existing + "\n\n" + codexImageGenerationBridgeText
+ return true
+}
+
+func applyCodexSparkImageUnsupportedInstructions(reqBody map[string]any) bool {
+ if len(reqBody) == 0 {
+ return false
+ }
+ existing, _ := reqBody["instructions"].(string)
+ if strings.Contains(existing, codexSparkImageUnsupportedMarker) {
+ return false
+ }
+ existing = strings.TrimRight(existing, " \t\r\n")
+ if strings.TrimSpace(existing) == "" {
+ reqBody["instructions"] = codexSparkImageUnsupportedText
+ return true
+ }
+ reqBody["instructions"] = existing + "\n\n" + codexSparkImageUnsupportedText
+ return true
+}
+
+func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error {
+ if !hasOpenAIImageGenerationTool(reqBody) {
+ return nil
+ }
+ model = strings.TrimSpace(model)
+ if !isOpenAIImageGenerationModel(model) {
+ return nil
+ }
+ return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model)
+}
+
+func normalizeOpenAIResponsesImageOnlyModel(reqBody map[string]any) bool {
+ if len(reqBody) == 0 {
+ return false
+ }
+ imageModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
+ if !isOpenAIImageGenerationModel(imageModel) {
+ return false
+ }
+
+ modified := false
+ tools, _ := reqBody["tools"].([]any)
+ imageToolIndex := -1
+ for i, rawTool := range tools {
+ toolMap, ok := rawTool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" {
+ imageToolIndex = i
+ break
+ }
+ }
+ if imageToolIndex < 0 {
+ tools = append(tools, map[string]any{
+ "type": "image_generation",
+ "model": imageModel,
+ })
+ imageToolIndex = len(tools) - 1
+ reqBody["tools"] = tools
+ modified = true
+ }
+
+ if toolMap, ok := tools[imageToolIndex].(map[string]any); ok {
+ if strings.TrimSpace(firstNonEmptyString(toolMap["model"])) == "" {
+ toolMap["model"] = imageModel
+ modified = true
+ }
+ for _, key := range []string{
+ "size",
+ "quality",
+ "background",
+ "output_format",
+ "output_compression",
+ "moderation",
+ "style",
+ "partial_images",
+ } {
+ if value, exists := reqBody[key]; exists && value != nil {
+ if _, toolHas := toolMap[key]; !toolHas {
+ toolMap[key] = value
+ }
+ delete(reqBody, key)
+ modified = true
+ }
+ }
+ }
+
+ if prompt := strings.TrimSpace(firstNonEmptyString(reqBody["prompt"])); prompt != "" {
+ if _, hasInput := reqBody["input"]; !hasInput {
+ reqBody["input"] = prompt
+ }
+ delete(reqBody, "prompt")
+ modified = true
+ }
+
+ if _, ok := reqBody["tool_choice"]; !ok {
+ reqBody["tool_choice"] = map[string]any{"type": "image_generation"}
+ modified = true
+ }
+ if imageModel != openAIImagesResponsesMainModel {
+ modified = true
+ }
+ reqBody["model"] = openAIImagesResponsesMainModel
+ return modified
}
func normalizeOpenAIModelForUpstream(account *Account, model string) string {
@@ -434,6 +909,14 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
typ, _ := m["type"].(string)
+ // chatgpt.com codex backend (OAuth path) does not persist reasoning
+ // items because applyCodexOAuthTransform forces store=false. Any rs_*
+ // reference replayed in input is guaranteed to 404 upstream
+ // ("Item with id 'rs_...' not found"). Drop reasoning items entirely.
+ if typ == "reasoning" {
+ continue
+ }
+
// 仅修正真正的 tool/function call 标识,避免误改普通 message/reasoning id;
// 若 item_reference 指向 legacy call_* 标识,则仅修正该引用本身。
fixCallIDPrefix := func(id string) string {
@@ -494,12 +977,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
}
+ if !isCodexToolCallItemType(typ) {
+ ensureCopy()
+ delete(newItem, "call_id")
+ }
+
+ if codexInputItemRequiresName(typ) {
+ if strings.TrimSpace(firstNonEmptyString(m["name"])) == "" {
+ name := firstNonEmptyString(m["tool_name"])
+ if name == "" {
+ if function, ok := m["function"].(map[string]any); ok {
+ name = firstNonEmptyString(function["name"])
+ }
+ }
+ if name == "" {
+ name = "tool"
+ }
+ ensureCopy()
+ newItem["name"] = name
+ }
+ }
+
if !preserveReferences {
ensureCopy()
delete(newItem, "id")
- if !isCodexToolCallItemType(typ) {
- delete(newItem, "call_id")
- }
}
filtered = append(filtered, newItem)
@@ -508,10 +1009,30 @@ func filterCodexInput(input []any, preserveReferences bool) []any {
}
func isCodexToolCallItemType(typ string) bool {
- if typ == "" {
+ switch typ {
+ case "function_call",
+ "tool_call",
+ "local_shell_call",
+ "tool_search_call",
+ "custom_tool_call",
+ "mcp_tool_call",
+ "function_call_output",
+ "mcp_tool_call_output",
+ "custom_tool_call_output",
+ "tool_search_output":
+ return true
+ default:
+ return false
+ }
+}
+
+func codexInputItemRequiresName(typ string) bool {
+ switch strings.TrimSpace(typ) {
+ case "function_call", "custom_tool_call", "mcp_tool_call":
+ return true
+ default:
return false
}
- return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
}
func normalizeCodexTools(reqBody map[string]any) bool {
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 993ade07..87bb7162 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -1,6 +1,8 @@
package service
import (
+ "fmt"
+ "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -92,6 +94,273 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
require.Equal(t, "fc1", second["call_id"])
}
+func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.2",
+ "input": []any{
+ map[string]any{"type": "tool_search_output", "call_id": "call_1", "output": "ok"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, false, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+
+ first, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "tool_search_output", first["type"])
+ require.Equal(t, "fc1", first["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.2",
+ "input": []any{
+ map[string]any{"type": "custom_tool_call_output", "call_id": "call_custom", "output": "ok"},
+ map[string]any{"type": "mcp_tool_call_output", "call_id": "call_mcp", "output": "ok"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, false, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 2)
+
+ first, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "fccustom", first["call_id"])
+
+ second, ok := input[1].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "fcmcp", second["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.2",
+ "input": []any{
+ map[string]any{"type": "image_generation_call", "id": "ig_123", "status": "completed"},
+ map[string]any{"type": "web_search_call", "call_id": "call_bad", "status": "completed"},
+ },
+ "tool_choice": "auto",
+ }
+
+ applyCodexOAuthTransform(reqBody, false, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 2)
+
+ first, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "ig_123", first["id"])
+ _, hasCallID := first["call_id"]
+ require.False(t, hasCallID)
+
+ second, ok := input[1].(map[string]any)
+ require.True(t, ok)
+ _, hasCallID = second["call_id"]
+ require.False(t, hasCallID)
+}
+
+func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{
+ "type": "message",
+ "role": "tool",
+ "tool_call_id": "call_1",
+ "content": "ok",
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function_call_output", item["type"])
+ require.Equal(t, "fc1", item["call_id"])
+ require.Equal(t, "ok", item["output"])
+ _, hasRole := item["role"]
+ require.False(t, hasRole)
+}
+
+func TestApplyCodexOAuthTransform_StringifiesNonStringMessageContentText(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": []any{"a", "b"}},
+ },
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ content, ok := item["content"].([]any)
+ require.True(t, ok)
+ part, ok := content[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, `["a","b"]`, part["text"])
+}
+
+func TestApplyCodexOAuthTransform_DowngradesUnknownToolChoice(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "function", "name": "shell"},
+ },
+ "tool_choice": map[string]any{"type": "custom"},
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ require.Equal(t, "auto", reqBody["tool_choice"])
+}
+
+func TestApplyCodexOAuthTransform_PreservesKnownToolChoice(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "custom", "name": "shell"},
+ },
+ "tool_choice": map[string]any{"type": "custom"},
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ choice, ok := reqBody["tool_choice"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "custom", choice["type"])
+}
+
+func TestApplyCodexOAuthTransform_NormalizesLegacyFunctionToolChoice(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "function", "name": "shell"},
+ },
+ "tool_choice": map[string]any{
+ "type": "function",
+ "function": map[string]any{"name": "shell"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ choice, ok := reqBody["tool_choice"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function", choice["type"])
+ require.Equal(t, "shell", choice["name"])
+ require.NotContains(t, choice, "function")
+}
+
+func TestApplyCodexOAuthTransform_DowngradesMissingFunctionToolChoice(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "function", "name": "shell"},
+ },
+ "tool_choice": map[string]any{
+ "type": "function",
+ "function": map[string]any{"name": "missing"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ require.Equal(t, "auto", reqBody["tool_choice"])
+}
+
+func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{"type": "message", "role": "user", "content": "run tool"},
+ map[string]any{"type": "function_call", "call_id": "call_1", "arguments": "{}"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 2)
+ item, ok := input[1].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function_call", item["type"])
+ require.Equal(t, "tool", item["name"])
+ require.Equal(t, "fc1", item["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{"type": "custom_tool_call", "call_id": "call_1", "name": "shell", "input": "pwd"},
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "shell", item["name"])
+ require.Equal(t, "fc1", item["call_id"])
+}
+
+func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": []any{
+ map[string]any{
+ "type": "mcp_tool_call",
+ "call_id": "call_abc",
+ "name": "remote_tool",
+ "arguments": "{}",
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+
+ input, ok := reqBody["input"].([]any)
+ require.True(t, ok)
+ require.Len(t, input, 1)
+ item, ok := input[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "mcp_tool_call", item["type"])
+ require.Equal(t, "remote_tool", item["name"])
+ require.Equal(t, "fcabc", item["call_id"])
+}
+
+func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
+ for _, typ := range []string{"function_call", "custom_tool_call", "mcp_tool_call"} {
+ require.True(t, codexInputItemRequiresName(typ), typ)
+ require.True(t, isCodexToolCallItemType(typ), typ)
+ }
+}
+
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true,保持 false。
@@ -217,6 +486,306 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
require.Equal(t, "bash", first["name"])
}
+func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) {
+ reqBody := map[string]any{
+ "tools": []any{
+ map[string]any{
+ "type": "image_generation",
+ "format": "png",
+ "compression": 60,
+ },
+ },
+ }
+
+ modified := normalizeOpenAIResponsesImageGenerationTools(reqBody)
+ require.True(t, modified)
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ first, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "png", first["output_format"])
+ require.Equal(t, 60, first["output_compression"])
+ _, hasFormat := first["format"]
+ require.False(t, hasFormat)
+ _, hasCompression := first["compression"]
+ require.False(t, hasCompression)
+}
+
+func TestEnsureOpenAIResponsesImageGenerationTool_NoTools(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "input": "draw a cat",
+ }
+
+ modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
+ require.True(t, modified)
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 1)
+ tool, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "image_generation", tool["type"])
+ require.Equal(t, "png", tool["output_format"])
+}
+
+func TestEnsureOpenAIResponsesImageGenerationTool_SkipsSpark(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "input": "draw a cat",
+ }
+
+ modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
+ require.False(t, modified)
+ require.NotContains(t, reqBody, "tools")
+}
+
+func TestEnsureOpenAIResponsesImageGenerationTool_AppendsToExistingTools(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "web_search"},
+ },
+ }
+
+ modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
+ require.True(t, modified)
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 2)
+ first, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "web_search", first["type"])
+ second, ok := tools[1].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "image_generation", second["type"])
+ require.Equal(t, "png", second["output_format"])
+}
+
+func TestEnsureOpenAIResponsesImageGenerationTool_PreservesExistingImageTool(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "tools": []any{
+ map[string]any{"type": "image_generation", "output_format": "webp"},
+ map[string]any{"type": "web_search"},
+ },
+ }
+
+ modified := ensureOpenAIResponsesImageGenerationTool(reqBody)
+ require.False(t, modified)
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 2)
+ tool, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "webp", tool["output_format"])
+}
+
+func TestApplyCodexImageGenerationBridgeInstructions_AppendsBridgeOnce(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "instructions": "existing instructions",
+ "tools": []any{
+ map[string]any{"type": "image_generation", "output_format": "png"},
+ },
+ }
+
+ modified := applyCodexImageGenerationBridgeInstructions(reqBody)
+ require.True(t, modified)
+
+ instructions, ok := reqBody["instructions"].(string)
+ require.True(t, ok)
+ require.Contains(t, instructions, "existing instructions")
+ require.Contains(t, instructions, codexImageGenerationBridgeMarker)
+ require.Contains(t, instructions, "Responses native `image_generation` tool")
+
+ modified = applyCodexImageGenerationBridgeInstructions(reqBody)
+ require.False(t, modified)
+}
+
+func TestApplyCodexImageGenerationBridgeInstructions_SkipsSpark(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "instructions": "existing instructions",
+ "tools": []any{
+ map[string]any{"type": "image_generation", "output_format": "png"},
+ },
+ }
+
+ modified := applyCodexImageGenerationBridgeInstructions(reqBody)
+ require.False(t, modified)
+ require.Equal(t, "existing instructions", reqBody["instructions"])
+}
+
+func TestApplyCodexImageGenerationBridgeInstructions_SkipsWithoutImageTool(t *testing.T) {
+ reqBody := map[string]any{
+ "instructions": "existing instructions",
+ "tools": []any{
+ map[string]any{"type": "web_search"},
+ },
+ }
+
+ modified := applyCodexImageGenerationBridgeInstructions(reqBody)
+ require.False(t, modified)
+ require.Equal(t, "existing instructions", reqBody["instructions"])
+}
+
+func TestValidateCodexSparkInputRejectsInputImage(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "input": []any{
+ map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": "describe"},
+ map[string]any{"type": "input_image", "image_url": "data:image/png;base64,aGVsbG8="},
+ },
+ },
+ },
+ }
+
+ err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "does not support image input")
+}
+
+func TestValidateCodexSparkInputRejectsChatImageURL(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "messages": []any{
+ map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "text", "text": "describe"},
+ map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,aGVsbG8="}},
+ },
+ },
+ },
+ }
+
+ err := validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark")
+ require.Error(t, err)
+}
+
+func TestValidateCodexSparkInputAllowsTextOnly(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "input": []any{
+ map[string]any{
+ "role": "user",
+ "content": []any{
+ map[string]any{"type": "input_text", "text": "hello"},
+ },
+ },
+ },
+ }
+
+ require.NoError(t, validateCodexSparkInput(reqBody, "gpt-5.3-codex-spark"))
+}
+
+func TestApplyCodexOAuthTransform_AddsSparkImageUnsupportedInstructions(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.3-codex-spark",
+ "instructions": "existing instructions",
+ "input": "hello",
+ }
+
+ result := applyCodexOAuthTransform(reqBody, true, false)
+ require.True(t, result.Modified)
+
+ instructions, ok := reqBody["instructions"].(string)
+ require.True(t, ok)
+ require.Contains(t, instructions, "existing instructions")
+ require.Contains(t, instructions, codexSparkImageUnsupportedMarker)
+ require.Contains(t, instructions, "does not support image generation")
+ require.Contains(t, instructions, "switch to a non-Spark Codex model")
+ require.NotContains(t, instructions, codexImageGenerationBridgeMarker)
+}
+
+func TestApplyCodexOAuthTransform_DoesNotAddSparkImageUnsupportedForNonSpark(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "instructions": "existing instructions",
+ "input": "hello",
+ }
+
+ applyCodexOAuthTransform(reqBody, true, false)
+ instructions, ok := reqBody["instructions"].(string)
+ require.True(t, ok)
+ require.NotContains(t, instructions, codexSparkImageUnsupportedMarker)
+}
+
+func TestNormalizeOpenAIResponsesImageOnlyModel_BuildsImageToolRequest(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-image-2",
+ "prompt": "draw a cat",
+ "size": "1024x1024",
+ "output_format": "png",
+ }
+
+ modified := normalizeOpenAIResponsesImageOnlyModel(reqBody)
+ require.True(t, modified)
+ require.Equal(t, openAIImagesResponsesMainModel, reqBody["model"])
+ require.Equal(t, "draw a cat", reqBody["input"])
+ _, hasPrompt := reqBody["prompt"]
+ require.False(t, hasPrompt)
+ _, hasTopLevelSize := reqBody["size"]
+ require.False(t, hasTopLevelSize)
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 1)
+ tool, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "image_generation", tool["type"])
+ require.Equal(t, "gpt-image-2", tool["model"])
+ require.Equal(t, "1024x1024", tool["size"])
+ require.Equal(t, "png", tool["output_format"])
+
+ choice, ok := reqBody["tool_choice"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "image_generation", choice["type"])
+}
+
+func TestNormalizeOpenAIResponsesImageOnlyModel_PreservesExistingImageTool(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-image-2",
+ "input": "draw a cat",
+ "tools": []any{
+ map[string]any{
+ "type": "image_generation",
+ "model": "gpt-image-1.5",
+ },
+ },
+ "tool_choice": "auto",
+ }
+
+ modified := normalizeOpenAIResponsesImageOnlyModel(reqBody)
+ require.True(t, modified)
+ require.Equal(t, openAIImagesResponsesMainModel, reqBody["model"])
+ require.Equal(t, "auto", reqBody["tool_choice"])
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 1)
+ tool, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "gpt-image-1.5", tool["model"])
+}
+
+func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) {
+ err := validateOpenAIResponsesImageModel(map[string]any{
+ "tools": []any{
+ map[string]any{"type": "image_generation"},
+ },
+ }, "gpt-image-2")
+
+ require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`)
+}
+
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
@@ -240,15 +809,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini",
- "gpt-5.4-nano": "gpt-5.4-nano",
- "gpt 5.4 nano": "gpt-5.4-nano",
"gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt 5.3 codex spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt 5.3 codex spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt 5.3 codex": "gpt-5.3-codex",
}
@@ -257,6 +824,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
}
}
+func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) {
+ cases := map[string]string{
+ "": "gpt-5.4",
+ "gpt-5": "gpt-5.4",
+ "gpt-5-mini": "gpt-5.4",
+ "gpt-5-nano": "gpt-5.4",
+ "gpt-5.1": "gpt-5.4",
+ "gpt-5.1-codex": "gpt-5.3-codex",
+ "gpt-5.1-codex-max": "gpt-5.3-codex",
+ "gpt-5.1-codex-mini": "gpt-5.3-codex",
+ "gpt-5.2-codex": "gpt-5.2",
+ "codex-mini-latest": "gpt-5.3-codex",
+ "gpt-5-codex": "gpt-5.3-codex",
+ }
+
+ for input, expected := range cases {
+ require.Equal(t, expected, normalizeCodexModel(input))
+ }
+}
+
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.3-codex-spark",
@@ -501,6 +1088,27 @@ func TestApplyCodexOAuthTransform_StripsPromptCacheRetention(t *testing.T) {
"prompt_cache_retention must be stripped before forwarding to Codex upstream")
}
+func TestApplyCodexOAuthTransform_StripsChatGPTInternalUnsupportedFields(t *testing.T) {
+ reqBody := map[string]any{
+ "model": "gpt-5.4",
+ "user": "user_123",
+ "metadata": map[string]any{"trace_id": "abc"},
+ "prompt_cache_retention": "24h",
+ "safety_identifier": "sid",
+ "stream_options": map[string]any{"include_usage": true},
+ "input": []any{
+ map[string]any{"role": "user", "content": "hi"},
+ },
+ }
+
+ result := applyCodexOAuthTransform(reqBody, true, false)
+
+ require.True(t, result.Modified)
+ for _, field := range openAIChatGPTInternalUnsupportedFields {
+ require.NotContains(t, reqBody, field)
+ }
+}
+
func TestApplyCodexOAuthTransform_ExtractsSystemMessages(t *testing.T) {
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -547,3 +1155,56 @@ func TestIsInstructionsEmpty(t *testing.T) {
})
}
}
+
+func TestFilterCodexInput_DropsReasoningItemsRegardlessOfPreserveReferences(t *testing.T) {
+ // Reasoning items in input[] reference rs_* IDs that were emitted by
+ // chatgpt.com under store=false (forced by applyCodexOAuthTransform).
+ // They are never persisted upstream, so forwarding them produces a
+ // guaranteed 404 ("Item with id 'rs_...' not found"). Drop them
+ // regardless of preserveReferences. See: Wei-Shaw/sub2api issue #1957.
+
+ build := func() []any {
+ return []any{
+ map[string]any{"type": "message", "id": "msg_0", "role": "user", "content": "hi"},
+ map[string]any{
+ "type": "reasoning",
+ "id": "rs_0672f12450da0b9c0169f07220a6c08198b68c2455ced99344",
+ "summary": []any{},
+ },
+ map[string]any{"type": "function_call", "id": "fc_1", "call_id": "call_1", "name": "tool"},
+ map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "{}"},
+ }
+ }
+
+ for _, preserve := range []bool{true, false} {
+ preserve := preserve
+ t.Run(fmt.Sprintf("preserveReferences=%v", preserve), func(t *testing.T) {
+ filtered := filterCodexInput(build(), preserve)
+
+ for _, raw := range filtered {
+ item, ok := raw.(map[string]any)
+ require.True(t, ok)
+ require.NotEqual(t, "reasoning", item["type"],
+ "reasoning items must be dropped from input on the OAuth path")
+ if id, ok := item["id"].(string); ok {
+ require.False(t, strings.HasPrefix(id, "rs_"),
+ "no item carrying an rs_* id should survive the filter")
+ }
+ }
+
+ // Sanity check: the non-reasoning items should still be present.
+ gotTypes := make(map[string]int)
+ for _, raw := range filtered {
+ item, ok := raw.(map[string]any)
+ require.True(t, ok)
+ typ, ok := item["type"].(string)
+ require.True(t, ok)
+ gotTypes[typ]++
+ }
+ require.Equal(t, 1, gotTypes["message"])
+ require.Equal(t, 1, gotTypes["function_call"])
+ require.Equal(t, 1, gotTypes["function_call_output"])
+ require.Equal(t, 0, gotTypes["reasoning"])
+ })
+ }
+}
diff --git a/backend/internal/service/openai_compact_model_mapping_test.go b/backend/internal/service/openai_compact_model_mapping_test.go
new file mode 100644
index 00000000..fc408e64
--- /dev/null
+++ b/backend/internal/service/openai_compact_model_mapping_test.go
@@ -0,0 +1,135 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayService_Forward_CompactOnlyModelMappingOverridesOAuthUpstreamModel(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"compact-test","input":"hello"}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-map"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_123","status":"completed","model":"gpt-5.4-openai-compact","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
+ },
+ Status: StatusActive,
+ Schedulable: true,
+ }
+
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "gpt-5.4", result.Model)
+ require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
+ require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
+}
+
+func TestOpenAIGatewayService_Forward_NonCompactRequestIgnoresCompactOnlyModelMapping(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gpt-5.4","stream":false,"instructions":"normal-test","input":"hello"}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-normal-map"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_124","status":"completed","model":"gpt-5.4","output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)),
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 2,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
+ },
+ Status: StatusActive,
+ Schedulable: true,
+ }
+
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "gpt-5.4", result.Model)
+ require.Equal(t, "gpt-5.4", result.UpstreamModel)
+ require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
+}
+
+func TestOpenAIGatewayService_OAuthPassthrough_CompactOnlyModelMappingOverridesUpstreamModel(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", bytes.NewReader(nil))
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ originalBody := []byte(`{"model":"gpt-5.4","stream":true,"store":true,"instructions":"compact-pass","input":[{"type":"text","text":"compact me"}]}`)
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-compact-pass-map"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_124","model":"gpt-5.4-openai-compact","usage":{"input_tokens":2,"output_tokens":3}}`)),
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 3,
+ Name: "openai-oauth-pass",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
+ },
+ Extra: map[string]any{"openai_passthrough": true},
+ Status: StatusActive,
+ Schedulable: true,
+ }
+
+ result, err := svc.Forward(context.Background(), c, account, originalBody)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "gpt-5.4", result.Model)
+ require.Equal(t, "gpt-5.4-openai-compact", result.UpstreamModel)
+ require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
+ require.Equal(t, "gpt-5.4", gjson.GetBytes(rec.Body.Bytes(), "model").String())
+}
diff --git a/backend/internal/service/openai_compact_probe.go b/backend/internal/service/openai_compact_probe.go
new file mode 100644
index 00000000..e8deff2d
--- /dev/null
+++ b/backend/internal/service/openai_compact_probe.go
@@ -0,0 +1,120 @@
+package service
+
+import (
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ // AccountTestModeDefault drives the standard /responses connection test.
+ AccountTestModeDefault = "default"
+ // AccountTestModeCompact drives the /responses/compact compact-probe test.
+ AccountTestModeCompact = "compact"
+)
+
+func normalizeAccountTestMode(mode string) string {
+ switch strings.ToLower(strings.TrimSpace(mode)) {
+ case AccountTestModeCompact:
+ return AccountTestModeCompact
+ default:
+ return AccountTestModeDefault
+ }
+}
+
+func createOpenAICompactProbePayload(model string) map[string]any {
+ return map[string]any{
+ "model": strings.TrimSpace(model),
+ "instructions": "You are a helpful coding assistant.",
+ "input": []any{
+ map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": "Respond with OK.",
+ },
+ },
+ }
+}
+
+func shouldMarkOpenAICompactUnsupported(status int, body []byte) bool {
+ switch status {
+ case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented:
+ return true
+ case http.StatusBadRequest, http.StatusForbidden, http.StatusUnprocessableEntity:
+ lower := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body) + " " + string(body)))
+ if strings.Contains(lower, "compact") {
+ for _, keyword := range []string{
+ "unsupported",
+ "not support",
+ "does not support",
+ "not available",
+ "disabled",
+ } {
+ if strings.Contains(lower, keyword) {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+func buildOpenAICompactProbeExtraUpdates(resp *http.Response, body []byte, probeErr error, now time.Time) map[string]any {
+ updates := map[string]any{
+ "openai_compact_checked_at": now.Format(time.RFC3339),
+ "openai_compact_last_status": nil,
+ }
+
+ if resp != nil {
+ updates["openai_compact_last_status"] = resp.StatusCode
+ }
+
+ switch {
+ case probeErr != nil:
+ updates["openai_compact_last_error"] = truncateString(sanitizeUpstreamErrorMessage(probeErr.Error()), 2048)
+ case resp == nil:
+ updates["openai_compact_last_error"] = "compact probe failed"
+ default:
+ errMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
+ if errMsg == "" && len(body) > 0 {
+ errMsg = strings.TrimSpace(string(body))
+ }
+ if errMsg == "" && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
+ errMsg = "HTTP " + strconv.Itoa(resp.StatusCode)
+ }
+ errMsg = truncateString(sanitizeUpstreamErrorMessage(errMsg), 2048)
+ if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+ updates["openai_compact_supported"] = true
+ updates["openai_compact_last_error"] = ""
+ } else {
+ if shouldMarkOpenAICompactUnsupported(resp.StatusCode, body) {
+ updates["openai_compact_supported"] = false
+ }
+ updates["openai_compact_last_error"] = errMsg
+ }
+ }
+
+ return updates
+}
+
+func mergeExtraUpdates(base map[string]any, more map[string]any) map[string]any {
+ if len(base) == 0 && len(more) == 0 {
+ return nil
+ }
+ out := make(map[string]any, len(base)+len(more))
+ for key, value := range base {
+ out[key] = value
+ }
+ for key, value := range more {
+ out[key] = value
+ }
+ return out
+}
+
+func compactProbeSessionID(accountID int64) string {
+ if accountID <= 0 {
+ return "probe_compact"
+ }
+ return "probe_compact_" + strconv.FormatInt(accountID, 10)
+}
diff --git a/backend/internal/service/openai_compact_probe_test.go b/backend/internal/service/openai_compact_probe_test.go
new file mode 100644
index 00000000..fe3ba0e8
--- /dev/null
+++ b/backend/internal/service/openai_compact_probe_test.go
@@ -0,0 +1,122 @@
+package service
+
+import (
+ "errors"
+ "net/http"
+ "testing"
+ "time"
+)
+
+func TestNormalizeAccountTestMode(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {input: "", want: AccountTestModeDefault},
+ {input: "default", want: AccountTestModeDefault},
+ {input: " compact ", want: AccountTestModeCompact},
+ {input: "COMPACT", want: AccountTestModeCompact},
+ {input: "unknown", want: AccountTestModeDefault},
+ }
+
+ for _, tt := range tests {
+ if got := normalizeAccountTestMode(tt.input); got != tt.want {
+ t.Fatalf("normalizeAccountTestMode(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_SuccessMarksSupported(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusOK}, []byte(`{"id":"cmp_1"}`), nil, now)
+
+ if got := updates["openai_compact_supported"]; got != true {
+ t.Fatalf("openai_compact_supported = %v, want true", got)
+ }
+ if got := updates["openai_compact_last_status"]; got != http.StatusOK {
+ t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusOK)
+ }
+ if got := updates["openai_compact_last_error"]; got != "" {
+ t.Fatalf("openai_compact_last_error = %v, want empty string", got)
+ }
+ if got := updates["openai_compact_checked_at"]; got != now.Format(time.RFC3339) {
+ t.Fatalf("openai_compact_checked_at = %v, want %s", got, now.Format(time.RFC3339))
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_404MarksUnsupported(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ body := []byte(`404 page not found`)
+ updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusNotFound}, body, nil, now)
+
+ if got := updates["openai_compact_supported"]; got != false {
+ t.Fatalf("openai_compact_supported = %v, want false", got)
+ }
+ if got := updates["openai_compact_last_status"]; got != http.StatusNotFound {
+ t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusNotFound)
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_502DoesNotMarkUnsupported(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadGateway}, []byte(`Upstream request failed`), nil, now)
+
+ if _, exists := updates["openai_compact_supported"]; exists {
+ t.Fatalf("did not expect openai_compact_supported for 502 response")
+ }
+ if got := updates["openai_compact_last_status"]; got != http.StatusBadGateway {
+ t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadGateway)
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_RequestErrorDoesNotMarkUnsupported(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ updates := buildOpenAICompactProbeExtraUpdates(nil, nil, errors.New("dial tcp timeout"), now)
+
+ if _, exists := updates["openai_compact_supported"]; exists {
+ t.Fatalf("did not expect openai_compact_supported for request error")
+ }
+ if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
+ t.Fatalf("openai_compact_last_status = %v, want nil key", got)
+ }
+ if got := updates["openai_compact_last_error"]; got == "" {
+ t.Fatalf("expected openai_compact_last_error to be populated")
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_NoResponseClearsLastStatus(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ updates := buildOpenAICompactProbeExtraUpdates(nil, nil, nil, now)
+
+ if got, exists := updates["openai_compact_last_status"]; !exists || got != nil {
+ t.Fatalf("openai_compact_last_status = %v, want nil key", got)
+ }
+ if got := updates["openai_compact_last_error"]; got != "compact probe failed" {
+ t.Fatalf("openai_compact_last_error = %v, want compact probe failed", got)
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_UnknownModelDoesNotMarkUnsupported(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ body := []byte(`{"error":{"message":"unknown model gpt-5.4-openai-compact"}}`)
+ updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusBadRequest}, body, nil, now)
+
+ if _, exists := updates["openai_compact_supported"]; exists {
+ t.Fatalf("did not expect openai_compact_supported for unknown-model diagnostics")
+ }
+ if got := updates["openai_compact_last_status"]; got != http.StatusBadRequest {
+ t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusBadRequest)
+ }
+}
+
+func TestBuildOpenAICompactProbeExtraUpdates_EmptyFailureBodyFallsBackToHTTPStatus(t *testing.T) {
+ now := time.Date(2026, 4, 10, 10, 0, 0, 0, time.UTC)
+ updates := buildOpenAICompactProbeExtraUpdates(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil, nil, now)
+
+ if got := updates["openai_compact_last_status"]; got != http.StatusServiceUnavailable {
+ t.Fatalf("openai_compact_last_status = %v, want %d", got, http.StatusServiceUnavailable)
+ }
+ if got := updates["openai_compact_last_error"]; got != "HTTP 503" {
+ t.Fatalf("openai_compact_last_error = %v, want HTTP 503", got)
+ }
+}
diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go
index 88e16a4d..fcd27f19 100644
--- a/backend/internal/service/openai_compat_prompt_cache_key.go
+++ b/backend/internal/service/openai_compat_prompt_cache_key.go
@@ -10,8 +10,14 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
- switch normalizeCodexModel(strings.TrimSpace(model)) {
- case "gpt-5.4", "gpt-5.3-codex":
+ trimmed := strings.TrimSpace(strings.ToLower(model))
+ // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel
+ // 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。
+ if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
+ return false
+ }
+ switch normalizeCodexModel(trimmed) {
+ case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
return true
default:
return false
diff --git a/backend/internal/service/openai_fast_policy_test.go b/backend/internal/service/openai_fast_policy_test.go
new file mode 100644
index 00000000..b52da614
--- /dev/null
+++ b/backend/internal/service/openai_fast_policy_test.go
@@ -0,0 +1,286 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type openAIFastPolicyRepoStub struct {
+ values map[string]string
+}
+
+func (s *openAIFastPolicyRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *openAIFastPolicyRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", ErrSettingNotFound
+}
+
+func (s *openAIFastPolicyRepoStub) Set(ctx context.Context, key, value string) error {
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ return nil
+}
+
+func (s *openAIFastPolicyRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ panic("unexpected GetMultiple call")
+}
+
+func (s *openAIFastPolicyRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *openAIFastPolicyRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *openAIFastPolicyRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolicySettings) *OpenAIGatewayService {
+ t.Helper()
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ if settings != nil {
+ raw, err := json.Marshal(settings)
+ require.NoError(t, err)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
+ }
+ return &OpenAIGatewayService{
+ settingService: NewSettingService(repo, &config.Config{}),
+ }
+}
+
+func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
+ // 是用户级开关,与 model 正交。
+ // gpt-5.5 + priority → filter
+ action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // gpt-5.5-turbo → filter
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // gpt-4 + priority → filter(默认策略覆盖所有模型)
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // gpt-5.5 + flex → pass (tier doesn't match)
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
+ require.Equal(t, BetaPolicyActionPass, action)
+
+ // empty tier → pass
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", "")
+ require.Equal(t, BetaPolicyActionPass, action)
+}
+
+func TestEvaluateOpenAIFastPolicy_BlockRuleCarriesMessage(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "fast mode is not allowed",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ action, msg := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionBlock, action)
+ require.Equal(t, "fast mode is not allowed", msg)
+}
+
+func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierAny,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeOAuth,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+
+ // OAuth account → rule matches
+ oauthAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth}
+ action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), oauthAccount, "gpt-4", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionFilter, action)
+
+ // API Key account → rule skipped → pass
+ apiKeyAccount := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), apiKeyAccount, "gpt-4", OpenAIFastTierPriority)
+ require.Equal(t, BetaPolicyActionPass, action)
+}
+
+func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // gpt-5.5 fast → service_tier stripped
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // Client sending "fast" (alias for priority) also filtered
+ body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
+ body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // No service_tier → no-op
+ body = []byte(`{"model":"gpt-5.5"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+}
+
+// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
+// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
+// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
+func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ for _, tier := range []string{"auto", "default", "scale"} {
+ body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err, "tier %q should pass without error", tier)
+ require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
+ "tier %q should be preserved in body under default rule", tier)
+ }
+
+ // evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
+ for _, tier := range []string{"auto", "default", "scale"} {
+ action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
+ require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
+ }
+}
+
+// TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers 验证管理员显式配置
+// ServiceTier=all + Action=filter 规则后,auto/default/scale 等官方 tier 也会
+// 被剥离。这是符合预期的——首条匹配 short-circuit,"all" 覆盖任意已识别 tier。
+func TestApplyOpenAIFastPolicyToBody_AllRuleStripsOfficialTiers(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierAny,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ for _, tier := range []string{"auto", "default", "scale", "priority", "flex"} {
+ body := []byte(`{"model":"gpt-5.5","service_tier":"` + tier + `"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`,
+ "tier %q should be stripped under ServiceTier=all + filter rule", tier)
+ }
+}
+
+// TestApplyOpenAIFastPolicyToBody_UnknownTierStripped 验证真未知 tier 仍被剥离
+// (normalize 返回 nil → normalizeResponsesBodyServiceTier 删除字段;
+// applyOpenAIFastPolicyToBody 在 normTier 为空时直接 no-op,因为字段已不可能存在
+// 于经过前置归一化的请求里。这里直接调 apply 验证它对未识别值不会异常)。
+func TestApplyOpenAIFastPolicyToBody_UnknownTierStripped(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // normalize 阶段会将未知值剥离
+ require.Nil(t, normalizeOpenAIServiceTier("xxx"))
+
+ // applyOpenAIFastPolicyToBody 收到未识别 tier 时不报错,body 透传不变
+ // (不属于本函数职责——上层 normalizeResponsesBodyServiceTier 已剥离)
+ body := []byte(`{"model":"gpt-5.5","service_tier":"xxx"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+}
+
+func TestApplyOpenAIFastPolicyToBody_BlockReturnsTypedError(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "fast mode is blocked for gpt-5.5",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.Error(t, err)
+ var blocked *OpenAIFastBlockedError
+ require.True(t, errors.As(err, &blocked))
+ require.Contains(t, blocked.Message, "fast mode is blocked")
+ require.Equal(t, string(body), string(updated)) // body not mutated on block
+}
+
+func TestSetOpenAIFastPolicySettings_Validation(t *testing.T) {
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ svc := NewSettingService(repo, &config.Config{})
+
+ // Invalid action rejected
+ err := svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: "bogus",
+ Scope: BetaPolicyScopeAll,
+ }},
+ })
+ require.Error(t, err)
+
+ // Invalid service_tier rejected
+ err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: "turbo",
+ Action: BetaPolicyActionPass,
+ Scope: BetaPolicyScopeAll,
+ }},
+ })
+ require.Error(t, err)
+
+ // Valid settings persisted
+ err = svc.SetOpenAIFastPolicySettings(context.Background(), &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ }},
+ })
+ require.NoError(t, err)
+
+ got, err := svc.GetOpenAIFastPolicySettings(context.Background())
+ require.NoError(t, err)
+ require.Len(t, got.Rules, 1)
+ require.Equal(t, OpenAIFastTierPriority, got.Rules[0].ServiceTier)
+}
diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go
new file mode 100644
index 00000000..3316a242
--- /dev/null
+++ b/backend/internal/service/openai_fast_policy_ws_test.go
@@ -0,0 +1,1018 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate ---
+
+func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier")
+ // Other fields preserved.
+ require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String())
+ require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String())
+ require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String())
+}
+
+func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Verbatim "fast" → normalized to "priority" → matches default rule → filter.
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ // Mixed-case + whitespace variant should also normalize and filter.
+ frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`)
+ updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(updated), `"service_tier"`)
+}
+
+func TestWSResponseCreate_FlexPassThrough(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Default policy targets priority only; flex is left untouched.
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, "flex", gjson.GetBytes(updated, "service_tier").String(), "flex frames must reach upstream untouched under default policy")
+}
+
+func TestWSResponseCreate_BlockReturnsTypedError(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "ws fast blocked",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.NotNil(t, blocked)
+ require.Equal(t, "ws fast blocked", blocked.Message)
+ // On block, payload returned unchanged so caller can inspect / log it.
+ require.Equal(t, string(frame), string(updated))
+}
+
+func TestWSResponseCreate_NoServiceTierUntouched(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","input":[]}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated), "no service_tier present must result in zero mutation")
+}
+
+func TestWSResponseCreate_NonResponseCreateFrameUntouched(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"*"},
+ FallbackAction: BetaPolicyActionFilter,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // response.cancel happens to carry a service_tier-shaped field — must not be touched.
+ frame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated))
+}
+
+// TestWSResponseCreate_EmptyTypeFrameUntouched is the A1 regression: the
+// helper used to treat empty type as response.create, which risked stripping
+// fields from malformed / unknown client events. After the A1 fix only a
+// strict "response.create" match triggers policy.
+func TestWSResponseCreate_EmptyTypeFrameUntouched(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"*"},
+ FallbackAction: BetaPolicyActionFilter,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Frame with no "type" field: must pass through completely unchanged
+ // even with a service_tier-shaped field present.
+ frame := []byte(`{"service_tier":"priority","model":"gpt-5.5"}`)
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated), "empty type must NOT be policy-checked — Realtime spec requires type, malformed frames are passed through")
+
+ // Explicit empty string also passes through.
+ frame = []byte(`{"type":"","service_tier":"priority","model":"gpt-5.5"}`)
+ updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, string(frame), string(updated))
+}
+
+// TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode is the B1
+// regression: the rendered Realtime error event must carry a non-empty
+// event_id (so clients can correlate the rejection) and a stable error.code
+// ("policy_violation"). The HTTP-side equivalent is the 403 permission_error
+// JSON body emitted by writeOpenAIFastPolicyBlockedResponse.
+func TestBuildOpenAIFastPolicyBlockedWSEvent_HasEventIDAndCode(t *testing.T) {
+ bytes := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "blocked because reasons"})
+ require.NotNil(t, bytes)
+
+ require.Equal(t, "error", gjson.GetBytes(bytes, "type").String())
+ require.Equal(t, "invalid_request_error", gjson.GetBytes(bytes, "error.type").String())
+ require.Equal(t, "policy_violation", gjson.GetBytes(bytes, "error.code").String())
+ require.Equal(t, "blocked because reasons", gjson.GetBytes(bytes, "error.message").String())
+
+ eventID := gjson.GetBytes(bytes, "event_id").String()
+ require.NotEmpty(t, eventID, "event_id must be present so clients can correlate the rejection in their logs")
+ require.True(t, strings.HasPrefix(eventID, "evt_"), "event_id should follow the evt_ Realtime convention; got %q", eventID)
+
+ // Sanity check: two consecutive events get distinct IDs.
+ other := buildOpenAIFastPolicyBlockedWSEvent(&OpenAIFastBlockedError{Message: "second"})
+ otherID := gjson.GetBytes(other, "event_id").String()
+ require.NotEqual(t, eventID, otherID, "event_id must be random per-event")
+}
+
+// TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe ensures the helper returns
+// nil for a nil error (defensive guard for callers that always invoke it).
+func TestBuildOpenAIFastPolicyBlockedWSEvent_NilSafe(t *testing.T) {
+ require.Nil(t, buildOpenAIFastPolicyBlockedWSEvent(nil))
+}
+
+// --- D5: passthrough wrapper FrameConn — capturedSessionModel fallback ---
+
+// fakePassthroughFrameConn replays a fixed sequence of client frames into the
+// policy-enforcing wrapper, then returns io.EOF. Captures all Write attempts
+// for write-side assertions (none expected in the D5 test, since the wrapper
+// only filters reads).
+type fakePassthroughFrameConn struct {
+ reads [][]byte
+ idx int
+ writes [][]byte
+ closeOnce bool
+}
+
+func (f *fakePassthroughFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ if f.idx >= len(f.reads) {
+ return coderws.MessageText, nil, errOpenAIWSConnClosed
+ }
+ payload := f.reads[f.idx]
+ f.idx++
+ return coderws.MessageText, payload, nil
+}
+
+func (f *fakePassthroughFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+ cp := append([]byte(nil), payload...)
+ f.writes = append(f.writes, cp)
+ return nil
+}
+
+func (f *fakePassthroughFrameConn) Close() error {
+ f.closeOnce = true
+ return nil
+}
+
+// gpt55WhitelistFastPolicy 返回一份强制带 model whitelist 的策略,用于
+// 验证 capturedSessionModel fallback 的语义(默认策略 whitelist 为空时
+// fallback 路径无法被观察到)。
+func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
+ return &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"gpt-5.5", "gpt-5.5*"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+}
+
+// TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel is
+// the D5 regression: in passthrough mode a follow-up response.create frame
+// without a "model" field must still hit the policy via the session-level
+// model captured from the first frame. Without the fallback an empty model
+// would miss a model whitelist and silently leak service_tier=priority
+// through to the upstream.
+func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) {
+ // 此处特意使用带 whitelist 的策略,以便观察 capturedSessionModel
+ // fallback 是否生效(默认策略 whitelist 为空,fallback 与否结果一致,
+ // 不能用来覆盖此回归)。
+ svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Simulate the passthrough adapter capturing model from the first frame.
+ firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
+ require.Equal(t, "gpt-5.5", capturedSessionModel)
+
+ // Follow-up frame deliberately omits "model" — Realtime allows this.
+ followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
+
+ inner := &fakePassthroughFrameConn{
+ reads: [][]byte{followupFrame},
+ }
+ wrapper := &openAIWSPolicyEnforcingFrameConn{
+ inner: inner,
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" {
+ model = capturedSessionModel
+ }
+ return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ },
+ }
+
+ // Read the follow-up frame through the wrapper. The policy MUST still
+ // trigger filter (gpt-5.5 + priority → filter), so the service_tier
+ // field is gone by the time the relay sees it.
+ _, payload, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.NotContains(t, string(payload), `"service_tier"`,
+ "D5 regression: empty model on follow-up frame must fall back to capturedSessionModel; whitelist policy filters service_tier=priority for gpt-5.5")
+ require.Equal(t, "response.create", gjson.GetBytes(payload, "type").String())
+}
+
+// TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses pins the
+// inverse: when the wrapper has NO capturedSessionModel fallback (model is
+// empty per-frame and no fallback is wired up), the policy fails to match
+// the model whitelist and the frame leaks through unchanged. This documents
+// exactly the leak the D5 fix prevents.
+func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing.T) {
+ // 同样使用带 whitelist 的策略以观察 leak。
+ svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ followupFrame := []byte(`{"type":"response.create","service_tier":"priority"}`)
+ inner := &fakePassthroughFrameConn{reads: [][]byte{followupFrame}}
+ wrapper := &openAIWSPolicyEnforcingFrameConn{
+ inner: inner,
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ // NO fallback — emulate the pre-fix behavior.
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ },
+ }
+
+ _, payload, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ // Pre-fix: empty model misses ["gpt-5.5","gpt-5.5*"] whitelist → fallback=pass → service_tier kept.
+ require.Contains(t, string(payload), `"service_tier"`,
+ "sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing")
+}
+
+// --- Ingress end-to-end test (filter path) ---
+
+// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the
+// real ProxyResponsesWebSocketFromClient ingress session pipeline against a
+// captureConn upstream and asserts that a client frame with service_tier=fast
+// is normalized + filtered out before being written upstream. This is the
+// integration flavour of TestWSResponseCreate_FilterStripsServiceTier.
+func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_ws_filter_1","model":"gpt-5.5","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings())
+ require.NoError(t, err)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ settingService: NewSettingService(repo, cfg),
+ }
+
+ account := &Account{
+ ID: 901,
+ Name: "openai-ws-filter",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() { _ = conn.CloseNow() }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ _, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() { _ = clientConn.CloseNow() }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"fast"}`)))
+ cancelWrite()
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, event, readErr := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, readErr)
+ require.Equal(t, "response.completed", gjson.GetBytes(event, "type").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Len(t, captureConn.writes, 1, "上游应只收到一条 response.create")
+ upstream := captureConn.writes[0]
+ _, hasServiceTier := upstream["service_tier"]
+ require.False(t, hasServiceTier, "上游收到的 response.create 不应包含 service_tier 字段(已被 fast policy filter 删除)")
+ require.Equal(t, "response.create", upstream["type"])
+ require.Equal(t, "gpt-5.5", upstream["model"])
+}
+
+// TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream is the
+// integration flavour of TestWSResponseCreate_BlockReturnsTypedError. It
+// asserts that with a custom block rule, the client receives a Realtime-style
+// error event AND the upstream FrameConn never receives the offending frame.
+func TestWSResponseCreate_IngressBlockSendsErrorEventAndSkipsUpstream(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ // No events queued; the upstream should never get written to anyway.
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ blockSettings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "ws priority blocked for testing",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
+ raw, err := json.Marshal(blockSettings)
+ require.NoError(t, err)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(raw)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ settingService: NewSettingService(repo, cfg),
+ }
+
+ account := &Account{
+ ID: 902,
+ Name: "openai-ws-block",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() { _ = conn.CloseNow() }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ _, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ proxyErr := svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ // Mirror the production handler (openai_gateway_handler.go:1325-1328):
+ // when the proxy returns an OpenAIWSClientCloseError, surface its
+ // status code to the client via a graceful close handshake. Without
+ // this the deferred CloseNow() above would tear down the TCP
+ // connection without sending a close frame, and the C3 timing
+ // assertion (next read returns CloseStatus=1008) would see EOF
+ // instead.
+ var closeErr *OpenAIWSClientCloseError
+ if errors.As(proxyErr, &closeErr) {
+ _ = conn.Close(closeErr.StatusCode(), closeErr.Reason())
+ }
+ serverErrCh <- proxyErr
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() { _ = clientConn.CloseNow() }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.5","stream":false,"service_tier":"priority"}`)))
+ cancelWrite()
+
+ // C3 timing assertion: the FIRST frame the client reads must be the
+ // error event — not a close frame. coder/websocket@v1.8.14 Conn.Write is
+ // synchronous (writeFrame Flushes the bufio writer at write.go:307-311
+ // before returning) and the close handshake re-acquires the same
+ // writeFrameMu, so this ordering is enforced by the library itself; this
+ // assertion guards against future refactors that might break it.
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, event, readErr := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, readErr, "first read must succeed and return the error event before any close frame")
+ require.Equal(t, "error", gjson.GetBytes(event, "type").String())
+ require.Equal(t, "invalid_request_error", gjson.GetBytes(event, "error.type").String())
+ // B1 regression: event_id + error.code must be populated.
+ require.Equal(t, "policy_violation", gjson.GetBytes(event, "error.code").String())
+ require.NotEmpty(t, gjson.GetBytes(event, "event_id").String(), "event_id must be present so clients can correlate")
+ require.Contains(t, gjson.GetBytes(event, "error.message").String(), "ws priority blocked for testing")
+
+ // Next read must surface the close frame (as a CloseError). This
+ // asserts the [error event, close] ordering — i.e. the close did NOT
+ // race ahead of the data frame.
+ readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, secondReadErr := clientConn.Read(readCtx2)
+ cancelRead2()
+ require.Error(t, secondReadErr, "after the error event the connection must surface a close")
+ require.Equal(t, coderws.StatusPolicyViolation, coderws.CloseStatus(secondReadErr),
+ "close status must be PolicyViolation; got %v", secondReadErr)
+
+ select {
+ case serverErr := <-serverErrCh:
+ // Server returns an OpenAIWSClientCloseError — handler closes the WS;
+ // here we just assert it surfaced as the typed close error.
+ require.Error(t, serverErr)
+ var closeErr *OpenAIWSClientCloseError
+ require.True(t, errors.As(serverErr, &closeErr), "block 应返回 OpenAIWSClientCloseError,得到 %T: %v", serverErr, serverErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress 关闭超时")
+ }
+
+ // Critical: the offending frame must NEVER reach the upstream.
+ // captureDialer.DialCount may legitimately be 0 or 1 depending on whether
+ // the lease was acquired before policy fired; either way, no writes.
+ require.Empty(t, captureConn.writes, "block 命中后上游不应收到 response.create")
+}
+
+// --- HTTP-side gap-filling tests (already covered by existing tests but
+// requested to be split out explicitly) ---
+
+// TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream confirms that
+// applyOpenAIFastPolicyToBody surfaces a *OpenAIFastBlockedError when the rule
+// action is "block", and that the body is left untouched. The caller (chat
+// completions / messages handlers) inspects this typed error and skips the
+// upstream HTTP call entirely — see openai_gateway_chat_completions.go:175 and
+// openai_gateway_messages.go:149.
+func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) {
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "priority blocked",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority","input":[]}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.Error(t, err)
+ var blocked *OpenAIFastBlockedError
+ require.True(t, errors.As(err, &blocked), "block must surface as typed error so caller can skip upstream HTTP request")
+ require.Equal(t, "priority blocked", blocked.Message)
+ require.Equal(t, string(body), string(updated), "block must not mutate body")
+}
+
+// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies
+// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode
+// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60)
+// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has
+// no service_tier. We exercise the same internal pipeline (Anthropic→Responses
+// + BetaFastMode + policy) without spinning up a real upstream HTTP server.
+func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Step 1: parse Anthropic request (mirrors openai_gateway_messages.go:38-50).
+ anthropicBody := []byte(`{"model":"gpt-5.5","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`)
+ var anthropicReq apicompat.AnthropicRequest
+ require.NoError(t, json.Unmarshal(anthropicBody, &anthropicReq))
+ responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
+ require.NoError(t, err)
+
+ // Step 2: BetaFastMode header → service_tier="priority" (mirrors line 58-61).
+ headers := http.Header{}
+ headers.Set("anthropic-beta", claude.BetaFastMode)
+ require.True(t, containsBetaToken(headers.Get("anthropic-beta"), claude.BetaFastMode))
+ responsesReq.ServiceTier = "priority"
+ responsesReq.Model = "gpt-5.5"
+
+ // Step 3: marshal & apply fast policy (mirrors line 78 + 149).
+ responsesBody, err := json.Marshal(responsesReq)
+ require.NoError(t, err)
+ require.Equal(t, "priority", gjson.GetBytes(responsesBody, "service_tier").String(), "前置:beta 翻译应当注入 priority")
+
+ upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody)
+ require.NoError(t, policyErr)
+
+ // Step 4: assert that policy filtered the field before the upstream HTTP request.
+ require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier")
+}
+
+// --- Fix1: passthrough capturedSessionModel must follow session.update ---
+
+// TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel covers the
+// fix1 bypass: client opens with a whitelist-miss model (gpt-4o → pass under
+// gpt-5.5 whitelist), rotates to gpt-5.5 via session.update, then sends
+// response.create without "model". Without the session.update sniffing the
+// follow-up frame would fall back to the stale gpt-4o capture and pass — the
+// fix updates capturedSessionModel from session.* events so the fallback now
+// resolves to gpt-5.5 and the policy filters service_tier.
+func TestPolicyEnforcingFrameConn_SessionUpdateRotatesCapturedModel(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Frame 1: response.create with whitelist-miss model — under default
+ // rule fallback=pass, service_tier stays.
+ first := []byte(`{"type":"response.create","model":"gpt-4o","service_tier":"priority"}`)
+ // Frame 2: session.update rotates the session model to gpt-5.5.
+ rotate := []byte(`{"type":"session.update","session":{"model":"gpt-5.5"}}`)
+ // Frame 3: response.create WITHOUT model — must inherit gpt-5.5.
+ followup := []byte(`{"type":"response.create","service_tier":"priority"}`)
+
+ inner := &fakePassthroughFrameConn{reads: [][]byte{first, rotate, followup}}
+
+ // Replicate the production wiring in openai_ws_v2_passthrough_adapter.go
+ // so capturedSessionModel state is shared across frames.
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, first)
+ require.Equal(t, "gpt-4o", capturedSessionModel)
+ wrapper := &openAIWSPolicyEnforcingFrameConn{
+ inner: inner,
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+ capturedSessionModel = updated
+ }
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" {
+ model = capturedSessionModel
+ }
+ return svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ },
+ }
+
+ // Frame 1: gpt-4o miss whitelist → pass (service_tier preserved).
+ _, payload1, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.Contains(t, string(payload1), `"service_tier"`, "frame1: gpt-4o miss whitelist → pass keeps service_tier")
+
+ // Frame 2: session.update — not response.create, untouched, but its
+ // side effect updates capturedSessionModel to gpt-5.5.
+ _, payload2, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, string(rotate), string(payload2), "session.update frame is forwarded verbatim")
+ require.Equal(t, "gpt-5.5", capturedSessionModel, "fix1: session.update must rotate capturedSessionModel")
+
+ // Frame 3: empty model + new captured gpt-5.5 → matches whitelist → filter.
+ _, payload3, err := wrapper.ReadFrame(context.Background())
+ require.NoError(t, err)
+ require.NotContains(t, string(payload3), `"service_tier"`,
+ "fix1: post-rotate response.create without model must use refreshed capturedSessionModel and trigger filter")
+}
+
+// TestPolicyModelFromSessionFrame_OnlySessionUpdate covers the negative
+// branches of openAIWSPassthroughPolicyModelFromSessionFrame: only
+// client→upstream session.update frames rotate the captured model;
+// server→client events (session.created) and unrelated frames must not.
+func TestPolicyModelFromSessionFrame_OnlySessionUpdate(t *testing.T) {
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // session.created is a server→client event in the OpenAI Realtime
+ // protocol — clients never send it, so this filter (which only runs on
+ // the client→upstream direction) must ignore it even if it appears.
+ created := []byte(`{"type":"session.created","session":{"model":"gpt-5.5"}}`)
+ require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, created))
+
+ // Non-session.* frames must NOT trigger rotation.
+ notSession := []byte(`{"type":"response.create","session":{"model":"gpt-9"}}`)
+ require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, notSession))
+
+ // Missing session.model returns empty — caller keeps the old captured value.
+ noModel := []byte(`{"type":"session.update","session":{"voice":"alloy"}}`)
+ require.Empty(t, openAIWSPassthroughPolicyModelFromSessionFrame(account, noModel))
+}
+
+// --- Fix2: native /responses normalize "fast" → "priority" on pass ---
+
+// TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias is the fix2
+// regression. Before the fix, when action=pass, applyOpenAIFastPolicyToBody
+// returned the body unchanged so a raw "fast" alias would leak to the
+// upstream OpenAI API (which does not accept "fast"). The fix normalizes
+// "fast" → "priority" on pass too.
+func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) {
+ // Use a policy that deliberately misses gpt-4 so the action is pass.
+ settings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, settings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // gpt-4 + "fast" → fallback pass. Body must be rewritten to "priority".
+ body := []byte(`{"model":"gpt-4","service_tier":"fast"}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(),
+ "fix2: pass action must still normalize 'fast' → 'priority' so upstream OpenAI accepts the slug")
+
+ // Already-canonical "priority" on pass: zero mutation (byte-equal).
+ body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+
+ // Mixed-case alias → normalized.
+ body = []byte(`{"model":"gpt-4","service_tier":" Fast "}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String())
+
+ // Unrecognized tier → still no-op (not normalized, since normTier == "").
+ body = []byte(`{"model":"gpt-4","service_tier":"turbo"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
+ require.NoError(t, err)
+ require.Equal(t, string(body), string(updated))
+}
+
+// --- Fix3: passthrough billing must reflect post-filter service_tier ---
+
+// TestPassthroughBilling_PostFilterServiceTier is the fix3 regression. The
+// passthrough adapter (openai_ws_v2_passthrough_adapter.go) now extracts
+// requestServiceTier from firstClientMessage AFTER applyOpenAIFastPolicy
+// has rewritten it, so a filter hit causes billing to report nil (default
+// tier) instead of the user-requested "priority". This test pins the
+// contract those two helpers must uphold for the adapter's billing path.
+func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+
+ // Pre-filter sanity: extracting from the raw frame would (incorrectly,
+ // pre-fix) report "priority" — this is the very thing the adapter
+ // must NOT do anymore.
+ pre := extractOpenAIServiceTierFromBody(raw)
+ require.NotNil(t, pre)
+ require.Equal(t, "priority", *pre,
+ "sanity: raw first frame carries priority that pre-fix billing would have reported")
+
+ // Apply policy filter (default rule: gpt-5.5 + priority → filter).
+ filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.NotContains(t, string(filtered), `"service_tier"`)
+
+ // Post-filter: extracting from the rewritten frame returns nil. This
+ // is the value the adapter now passes to OpenAIForwardResult.ServiceTier,
+ // so billing records "default" instead of "priority".
+ post := extractOpenAIServiceTierFromBody(filtered)
+ require.Nil(t, post, "fix3: post-filter extraction must return nil so passthrough billing reports default tier instead of the requested priority")
+
+ // And the byte-level invariant the adapter relies on: filtering an
+ // already-filtered frame is a no-op (idempotent), so re-running the
+ // policy doesn't accidentally re-introduce the field.
+ again, blocked2, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", filtered)
+ require.NoError(t, err)
+ require.Nil(t, blocked2)
+ require.Equal(t, string(filtered), string(again),
+ "policy is idempotent: filtering an already-filtered frame leaves bytes unchanged")
+}
+
+// TestApplyOpenAIFastPolicyToBody_NonStringServiceTier covers the test gap
+// flagged in the review: when a client sends service_tier as a non-string
+// (number, null, object, etc.) the policy must NOT panic and must NOT
+// pretend the field was filtered. Behavior: skip policy entirely (treat as
+// "no usable tier"), forward body unchanged. This mirrors the HTTP entry's
+// type-assertion `reqBody["service_tier"].(string); ok` guard.
+func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Number — gjson .String() coerces to "1" which is not a recognized
+ // tier alias; normalize returns "" → policy no-ops.
+ cases := [][]byte{
+ []byte(`{"model":"gpt-5.5","service_tier":1}`),
+ []byte(`{"model":"gpt-5.5","service_tier":null}`),
+ []byte(`{"model":"gpt-5.5","service_tier":{"nested":"priority"}}`),
+ []byte(`{"model":"gpt-5.5","service_tier":["priority"]}`),
+ []byte(`{"model":"gpt-5.5","service_tier":true}`),
+ }
+ for _, body := range cases {
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err, "non-string service_tier must not error: %s", string(body))
+ require.Equal(t, string(body), string(updated),
+ "non-string service_tier must pass through unchanged: %s", string(body))
+ }
+
+ // Same guard for the WS response.create entry.
+ for _, body := range cases {
+ frame := body
+ updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err, "non-string service_tier ws frame must not error: %s", string(frame))
+ require.Nil(t, blocked, "non-string service_tier must not trigger block: %s", string(frame))
+ require.Equal(t, string(frame), string(updated),
+ "non-string service_tier ws frame must pass through unchanged: %s", string(frame))
+ }
+}
+
+// TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames covers the
+// multi-turn passthrough billing regression: OpenAI Realtime / Responses WS
+// allows the client to ship a different service_tier on each response.create
+// frame (per-response field, see codex-rs/core/src/client.rs
+// build_responses_request which re-fills the field on every request). Before
+// the fix the adapter only captured service_tier from firstClientMessage so
+// turn 2/3 billing was wrong. After the fix the filter closure refreshes an
+// atomic.Pointer[string] on every successful response.create frame.
+//
+// This test pins the four legs of the semantic contract:
+// - turn 1: service_tier=priority hits the default whitelist filter, so
+// after filter the upstream sees no tier → billing is nil.
+// - turn 2: service_tier=flex passes (default rule targets priority only),
+// billing should now reflect "flex".
+// - turn 3: response.create without any service_tier — the upstream will
+// treat it as default; we choose to mirror that and overwrite billing
+// to nil rather than carry over "flex" from turn 2.
+// - non-response.create frame (response.cancel here) carrying a stray
+// service_tier-shaped field must NOT clobber the billing pointer.
+func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ // Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go
+ // proxyResponsesWebSocketV2Passthrough) so this test fails if the
+ // production code drops the per-frame Store.
+ var requestServiceTierPtr atomic.Pointer[string]
+ capturedSessionModel := ""
+ filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+ capturedSessionModel = updated
+ }
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" {
+ model = capturedSessionModel
+ }
+ out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, model, payload)
+ if policyErr == nil && blocked == nil &&
+ strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+ }
+ return out, blocked, policyErr
+ }
+
+ // First-frame initialization mirrors the adapter: extract from the
+ // post-filter payload so a filter-on-first-frame zeroes billing too.
+ firstFrame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ firstOut, firstBlocked, firstErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", firstFrame)
+ require.NoError(t, firstErr)
+ require.Nil(t, firstBlocked)
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstOut))
+ capturedSessionModel = openAIWSPassthroughPolicyModelForFrame(account, firstFrame)
+ require.Nil(t, requestServiceTierPtr.Load(),
+ "turn 1: filter strips service_tier=priority, billing must reflect upstream-actual nil tier")
+
+ // Turn 2: client switches to flex, should pass and update billing.
+ turn2 := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
+ out2, blocked2, err2 := filter(coderws.MessageText, turn2)
+ require.NoError(t, err2)
+ require.Nil(t, blocked2)
+ require.Equal(t, "flex", gjson.GetBytes(out2, "service_tier").String(), "turn 2: flex must pass to upstream untouched")
+ tier2 := requestServiceTierPtr.Load()
+ require.NotNil(t, tier2, "turn 2: billing must update to reflect flex")
+ require.Equal(t, "flex", *tier2)
+
+ // A non-response.create frame with a stray service_tier-shaped field
+ // must NOT overwrite the billing pointer (those frames don't carry
+ // per-response service_tier in the Realtime spec).
+ cancelFrame := []byte(`{"type":"response.cancel","service_tier":"priority"}`)
+ _, blockedCancel, errCancel := filter(coderws.MessageText, cancelFrame)
+ require.NoError(t, errCancel)
+ require.Nil(t, blockedCancel)
+ tierAfterCancel := requestServiceTierPtr.Load()
+ require.NotNil(t, tierAfterCancel, "response.cancel must not clobber billing tier to nil")
+ require.Equal(t, "flex", *tierAfterCancel,
+ "non-response.create frames must not update billing tier even if they carry a service_tier-shaped field")
+
+ // Turn 3: response.create without any service_tier. We deliberately
+ // overwrite billing back to nil so it tracks what the upstream actually
+ // sees on this turn (default tier).
+ turn3 := []byte(`{"type":"response.create","model":"gpt-5.5"}`)
+ out3, blocked3, err3 := filter(coderws.MessageText, turn3)
+ require.NoError(t, err3)
+ require.Nil(t, blocked3)
+ require.Equal(t, string(turn3), string(out3), "turn 3 has no service_tier — filter must not mutate")
+ require.Nil(t, requestServiceTierPtr.Load(),
+ "turn 3: response.create without service_tier overwrites billing to nil to match upstream default")
+}
+
+// TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier locks in the
+// "block keeps previous" semantic: when policy returns block on a
+// response.create frame, that frame is never sent upstream, so billing tier
+// must keep the previous turn's value rather than getting silently zeroed.
+func TestPassthroughBilling_BlockedFrameDoesNotMutateServiceTier(t *testing.T) {
+ blockSettings := &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionBlock,
+ Scope: BetaPolicyScopeAll,
+ ErrorMessage: "blocked",
+ ModelWhitelist: []string{"gpt-5.5"},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+ svc := newOpenAIGatewayServiceWithSettings(t, blockSettings)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ var requestServiceTierPtr atomic.Pointer[string]
+ flexValue := "flex"
+ requestServiceTierPtr.Store(&flexValue) // simulate prior turn billed as flex
+
+ filter := func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ out, blocked, policyErr := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", payload)
+ if policyErr == nil && blocked == nil &&
+ strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+ }
+ return out, blocked, policyErr
+ }
+
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
+ _, blocked, err := filter(coderws.MessageText, frame)
+ require.NoError(t, err)
+ require.NotNil(t, blocked, "policy must block this frame")
+
+ tier := requestServiceTierPtr.Load()
+ require.NotNil(t, tier, "blocked frame must not clobber prior billing tier to nil")
+ require.Equal(t, "flex", *tier,
+ "blocked frame is never sent upstream; billing must retain the previous turn's tier")
+}
diff --git a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go
new file mode 100644
index 00000000..c6805464
--- /dev/null
+++ b/backend/internal/service/openai_gateway_403_reset_test.go
@@ -0,0 +1,39 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type openAI403CounterResetStub struct {
+ resetCalls []int64
+}
+
+func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) {
+ return 0, nil
+}
+
+func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
+ s.resetCalls = append(s.resetCalls, accountID)
+ return nil
+}
+
+func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
+ counter := &openAI403CounterResetStub{}
+ rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
+ rateLimitSvc.SetOpenAI403CounterCache(counter)
+
+ svc := &OpenAIGatewayService{
+ rateLimitService: rateLimitSvc,
+ }
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{},
+ Account: &Account{ID: 777, Platform: PlatformOpenAI},
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, []int64{777}, counter.resetCalls)
+}
diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index ac7d28a7..5822ae4c 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
responsesBody = stripped
}
}
+ responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
+ if err != nil {
+ return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
+ }
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
- ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
+ ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
+ normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
@@ -166,6 +171,17 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
}
}
+ // 4b. Apply OpenAI fast policy (may filter service_tier or block the request).
+ updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
+ if policyErr != nil {
+ var blocked *OpenAIFastBlockedError
+ if errors.As(policyErr, &blocked) {
+ writeChatCompletionsError(c, http.StatusForbidden, "permission_error", blocked.Message)
+ }
+ return nil, policyErr
+ }
+ responsesBody = updatedBody
+
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
@@ -274,6 +290,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return result, handleErr
}
+func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
+ if req == nil {
+ return
+ }
+ req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
+}
+
+func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
+ if len(body) == 0 {
+ return body, "", nil
+ }
+ rawServiceTier := gjson.GetBytes(body, "service_tier").String()
+ if rawServiceTier == "" {
+ return body, "", nil
+ }
+ normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
+ if normalizedServiceTier == "" {
+ trimmed, err := sjson.DeleteBytes(body, "service_tier")
+ return trimmed, "", err
+ }
+ if normalizedServiceTier == rawServiceTier {
+ return body, normalizedServiceTier, nil
+ }
+ trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
+ return trimmed, normalizedServiceTier, err
+}
+
+func normalizedOpenAIServiceTierValue(raw string) string {
+ normalized := normalizeOpenAIServiceTier(raw)
+ if normalized == nil {
+ return ""
+ }
+ return *normalized
+}
+
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
new file mode 100644
index 00000000..6846e03a
--- /dev/null
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -0,0 +1,75 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
+ t.Parallel()
+
+ req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "priority", req.ServiceTier)
+
+ req.ServiceTier = "flex"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "flex", req.ServiceTier)
+
+ // OpenAI 官方合法 tier 应被透传保留。
+ req.ServiceTier = "auto"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "auto", req.ServiceTier)
+
+ req.ServiceTier = "default"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "default", req.ServiceTier)
+
+ req.ServiceTier = "scale"
+ normalizeResponsesRequestServiceTier(req)
+ require.Equal(t, "scale", req.ServiceTier)
+
+ // 真未知值仍被剥离。
+ req.ServiceTier = "turbo"
+ normalizeResponsesRequestServiceTier(req)
+ require.Empty(t, req.ServiceTier)
+}
+
+func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
+ t.Parallel()
+
+ body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
+ require.NoError(t, err)
+ require.Equal(t, "priority", tier)
+ require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
+ require.NoError(t, err)
+ require.Equal(t, "flex", tier)
+ require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
+
+ // OpenAI 官方 tier 直接保留在 body 中(透传上游)。
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"auto"}`))
+ require.NoError(t, err)
+ require.Equal(t, "auto", tier)
+ require.Equal(t, "auto", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
+ require.NoError(t, err)
+ require.Equal(t, "default", tier)
+ require.Equal(t, "default", gjson.GetBytes(body, "service_tier").String())
+
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"scale"}`))
+ require.NoError(t, err)
+ require.Equal(t, "scale", tier)
+ require.Equal(t, "scale", gjson.GetBytes(body, "service_tier").String())
+
+ // 真未知值才会被删除。
+ body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"turbo"}`))
+ require.NoError(t, err)
+ require.Empty(t, tier)
+ require.False(t, gjson.GetBytes(body, "service_tier").Exists())
+}
diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go
index 2a0a72eb..4e0ebb2e 100644
--- a/backend/internal/service/openai_gateway_messages.go
+++ b/backend/internal/service/openai_gateway_messages.go
@@ -143,6 +143,19 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
}
}
+ // 4c. Apply OpenAI fast policy (may filter service_tier or block the request).
+ // Mirrors the Claude anthropic-beta "fast-mode-2026-02-01" filter, but keyed
+ // on the body-level service_tier field (priority/flex).
+ updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, upstreamModel, responsesBody)
+ if policyErr != nil {
+ var blocked *OpenAIFastBlockedError
+ if errors.As(policyErr, &blocked) {
+ writeAnthropicError(c, http.StatusForbidden, "forbidden_error", blocked.Message)
+ }
+ return nil, policyErr
+ }
+ responsesBody = updatedBody
+
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index e6fa94aa..47ff4e3b 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -148,6 +148,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
nil,
nil,
+ nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
@@ -826,18 +827,29 @@ func TestNormalizeOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *got)
})
- t.Run("default ignored", func(t *testing.T) {
- require.Nil(t, normalizeOpenAIServiceTier("default"))
+ t.Run("openai official tiers preserved", func(t *testing.T) {
+ // OpenAI 官方文档定义的合法 tier 值都应被透传保留,避免因白名单过窄
+ // 静默剥离客户端显式发送的合法字段。Codex 客户端只发 priority/flex,
+ // 所以扩大白名单对 Codex 流量零影响(见 codex-rs/core/src/client.rs)。
+ for _, tier := range []string{"priority", "flex", "auto", "default", "scale"} {
+ got := normalizeOpenAIServiceTier(tier)
+ require.NotNil(t, got, "tier %q should not be normalized to nil", tier)
+ require.Equal(t, tier, *got)
+ }
})
t.Run("invalid ignored", func(t *testing.T) {
require.Nil(t, normalizeOpenAIServiceTier("turbo"))
+ require.Nil(t, normalizeOpenAIServiceTier("xxx"))
})
}
func TestExtractOpenAIServiceTier(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"}))
require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"}))
+ require.Equal(t, "auto", *extractOpenAIServiceTier(map[string]any{"service_tier": "auto"}))
+ require.Equal(t, "default", *extractOpenAIServiceTier(map[string]any{"service_tier": "default"}))
+ require.Equal(t, "scale", *extractOpenAIServiceTier(map[string]any{"service_tier": "scale"}))
require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1}))
require.Nil(t, extractOpenAIServiceTier(nil))
}
@@ -845,7 +857,10 @@ func TestExtractOpenAIServiceTier(t *testing.T) {
func TestExtractOpenAIServiceTierFromBody(t *testing.T) {
require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`)))
require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`)))
- require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
+ require.Equal(t, "auto", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"auto"}`)))
+ require.Equal(t, "default", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`)))
+ require.Equal(t, "scale", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"scale"}`)))
+ require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"turbo"}`)))
require.Nil(t, extractOpenAIServiceTierFromBody(nil))
}
@@ -1031,7 +1046,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
Model: "gpt-5.1",
Duration: time.Second,
},
- APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}},
+ APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
User: &User{ID: 200},
Account: &Account{ID: 300},
Subscription: subscription,
@@ -1070,3 +1085,78 @@ func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *t
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
+
+func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_image_only_usage",
+ Model: "gpt-image-2",
+ ImageCount: 2,
+ ImageSize: "1K",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{ID: 1007},
+ User: &User{ID: 2007},
+ Account: &Account{ID: 3007},
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, 2, usageRepo.lastLog.ImageCount)
+ require.NotNil(t, usageRepo.lastLog.ImageSize)
+ require.Equal(t, "1K", *usageRepo.lastLog.ImageSize)
+ require.NotNil(t, usageRepo.lastLog.BillingMode)
+ require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
+ imagePrice := 0.02
+ groupID := int64(12)
+
+ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_image_per_request",
+ Model: "gpt-image-2",
+ Usage: OpenAIUsage{
+ InputTokens: 1110,
+ OutputTokens: 1756,
+ ImageOutputTokens: 1756,
+ },
+ ImageCount: 2,
+ ImageSize: "1K",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1008,
+ GroupID: i64p(groupID),
+ Group: &Group{
+ ID: groupID,
+ RateMultiplier: 1.0,
+ ImagePrice1K: &imagePrice,
+ },
+ },
+ User: &User{ID: 2008},
+ Account: &Account{ID: 3008},
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.NotNil(t, usageRepo.lastLog.BillingMode)
+ require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
+ require.Equal(t, 2, usageRepo.lastLog.ImageCount)
+ require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12)
+ require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12)
+ require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12)
+ require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
+ require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 064191bd..ed69730c 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -40,7 +40,7 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
- codexCLIUserAgent = "codex_cli_rs/0.104.0"
+ codexCLIUserAgent = "codex_cli_rs/0.125.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
codexCLIOnlyHeaderValueMaxBytes = 256
@@ -54,7 +54,7 @@ const (
openAIWSRetryBackoffMaxDefault = 2 * time.Second
openAIWSRetryJitterRatioDefault = 0.2
openAICompactSessionSeedKey = "openai_compact_session_seed"
- codexCLIVersion = "0.104.0"
+ codexCLIVersion = "0.125.0"
// Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。
openAICodexSnapshotPersistMinInterval = 30 * time.Second
)
@@ -233,6 +233,8 @@ type OpenAIForwardResult struct {
ResponseHeaders http.Header
Duration time.Duration
FirstTokenMs *int
+ ImageCount int
+ ImageSize string
}
type OpenAIWSRetryMetricsSnapshot struct {
@@ -304,6 +306,10 @@ func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool {
var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval)
+// ErrNoAvailableCompactAccounts indicates the request needs /responses/compact
+// support but no compatible account is available.
+var ErrNoAvailableCompactAccounts = errors.New("no available OpenAI accounts support /responses/compact")
+
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
@@ -328,6 +334,7 @@ type OpenAIGatewayService struct {
resolver *ModelPricingResolver
channelService *ChannelService
balanceNotifyService *BalanceNotifyService
+ settingService *SettingService
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
@@ -366,6 +373,7 @@ func NewOpenAIGatewayService(
resolver *ModelPricingResolver,
channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
+ settingService *SettingService,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -396,6 +404,7 @@ func NewOpenAIGatewayService(
resolver: resolver,
channelService: channelService,
balanceNotifyService: balanceNotifyService,
+ settingService: settingService,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
@@ -440,11 +449,11 @@ func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Contex
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
-func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
+func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string, requireCompact bool) bool {
if s.channelService == nil {
return false
}
- upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
+ upstreamModel := resolveOpenAIAccountUpstreamModelForRequest(account, requestedModel, requireCompact)
if upstreamModel == "" {
return false
}
@@ -1119,6 +1128,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID
}
+func explicitOpenAISessionID(c *gin.Context, body []byte) string {
+ if c == nil {
+ return ""
+ }
+
+ sessionID := strings.TrimSpace(c.GetHeader("session_id"))
+ if sessionID == "" {
+ sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
+ }
+ if sessionID == "" && len(body) > 0 {
+ sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
+ }
+ return sessionID
+}
+
+// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
+// client session signals. It intentionally skips content-derived fallback and is
+// used by stateless endpoints such as /v1/images.
+func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
+ sessionID := explicitOpenAISessionID(c, body)
+ if sessionID == "" {
+ return ""
+ }
+
+ currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
+ attachOpenAILegacySessionHashToGin(c, legacyHash)
+ return currentHash
+}
+
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
@@ -1131,13 +1169,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
- sessionID := strings.TrimSpace(c.GetHeader("session_id"))
- if sessionID == "" {
- sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
- }
- if sessionID == "" && len(body) > 0 {
- sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
- }
+ sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}
@@ -1206,10 +1238,94 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
- return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0)
+ return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0)
}
-func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
+// noAvailableOpenAISelectionError builds the standard "no account available" error
+// while preserving the compact-specific error when applicable.
+func noAvailableOpenAISelectionError(requestedModel string, compactBlocked bool) error {
+ if compactBlocked {
+ return ErrNoAvailableCompactAccounts
+ }
+ if requestedModel != "" {
+ return fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
+ }
+ return errors.New("no available OpenAI accounts")
+}
+
+// openAICompactSupportTier classifies an OpenAI account by compact capability.
+// 0 = explicitly unsupported, 1 = unknown / not yet probed, 2 = explicitly supported.
+func openAICompactSupportTier(account *Account) int {
+ if account == nil || !account.IsOpenAI() {
+ return 0
+ }
+ supported, known := account.OpenAICompactSupportKnown()
+ if !known {
+ return 1
+ }
+ if supported {
+ return 2
+ }
+ return 0
+}
+
+// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
+// compact-support checks used during account selection.
+func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool {
+ if account == nil || !account.IsSchedulable() || !account.IsOpenAI() {
+ return false
+ }
+ if requestedModel != "" && !account.IsModelSupported(requestedModel) {
+ return false
+ }
+ if requireCompact && openAICompactSupportTier(account) == 0 {
+ return false
+ }
+ return true
+}
+
+// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known
+// compact support are tried first, followed by unknown, then explicitly unsupported.
+// The relative order within each tier is preserved.
+func prioritizeOpenAICompactAccounts(accounts []*Account) []*Account {
+ if len(accounts) == 0 {
+ return nil
+ }
+ supported := make([]*Account, 0, len(accounts))
+ unknown := make([]*Account, 0, len(accounts))
+ unsupported := make([]*Account, 0, len(accounts))
+ for _, account := range accounts {
+ switch openAICompactSupportTier(account) {
+ case 2:
+ supported = append(supported, account)
+ case 1:
+ unknown = append(unknown, account)
+ default:
+ unsupported = append(unsupported, account)
+ }
+ }
+ out := make([]*Account, 0, len(accounts))
+ out = append(out, supported...)
+ out = append(out, unknown...)
+ out = append(out, unsupported...)
+ return out
+}
+
+// resolveOpenAIAccountUpstreamModelForRequest resolves the upstream model that
+// would be sent for a given request, honouring compact-only mappings when the
+// caller is on the /responses/compact path.
+func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedModel string, requireCompact bool) string {
+ upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
+ if upstreamModel == "" {
+ return ""
+ }
+ if requireCompact {
+ return resolveOpenAICompactForwardModel(account, upstreamModel)
+ }
+ return upstreamModel
+}
+
+func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
@@ -1219,7 +1335,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 1. 尝试粘性会话命中
// Try sticky session hit
- if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
+ if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil {
return account, nil
}
@@ -1232,13 +1348,10 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU
- selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
+ selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact)
if selected == nil {
- if requestedModel != "" {
- return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
- }
- return nil, errors.New("no available OpenAI accounts")
+ return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
}
// 4. 设置粘性会话绑定
@@ -1255,7 +1368,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
-func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account {
+func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account {
if sessionHash == "" {
return nil
}
@@ -1287,19 +1400,16 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 验证账号是否可用于当前请求
// Verify account is usable for current request
- if !account.IsSchedulable() || !account.IsOpenAI() {
+ if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
return nil
}
- if requestedModel != "" && !account.IsModelSupported(requestedModel) {
- return nil
- }
- account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
+ account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
- s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
+ s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
@@ -1314,9 +1424,13 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 返回 nil 表示无可用账号。
//
// selectBestAccount selects the best account from candidates (priority + LRU).
-// Returns nil if no available account.
-func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
+// Returns nil if no available account. The second return reports whether at
+// least one candidate was filtered out solely because it lacks compact support
+// (only meaningful when requireCompact=true).
+func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) {
var selected *Account
+ selectedCompactTier := -1
+ compactBlocked := false
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts {
@@ -1328,31 +1442,50 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *i
continue
}
- fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
+ fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
if fresh == nil {
continue
}
- fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel)
+ fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false)
if fresh == nil {
continue
}
- if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
+ if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
+ compactTier := 0
+ if requireCompact {
+ compactTier = openAICompactSupportTier(fresh)
+ if compactTier == 0 {
+ compactBlocked = true
+ continue
+ }
+ }
// 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used
if selected == nil {
selected = fresh
+ selectedCompactTier = compactTier
+ continue
+ }
+
+ // compact 模式下高 tier 优先;同 tier 内才比较 priority/LRU。
+ if requireCompact && compactTier != selectedCompactTier {
+ if compactTier > selectedCompactTier {
+ selected = fresh
+ selectedCompactTier = compactTier
+ }
continue
}
if s.isBetterAccount(fresh, selected) {
selected = fresh
+ selectedCompactTier = compactTier
}
}
- return selected
+ return selected, compactBlocked
}
// isBetterAccount 判断 candidate 是否比 current 更优。
@@ -1390,6 +1523,10 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
+ return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false)
+}
+
+func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
@@ -1406,7 +1543,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
- account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID)
+ account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID)
if err != nil {
return nil, err
}
@@ -1459,12 +1596,11 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if clearSticky {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
}
- if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
- (requestedModel == "" || account.IsModelSupported(requestedModel)) {
- account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
+ if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
+ account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
- } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
+ } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
@@ -1489,6 +1625,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
}
// ============ Layer 2: Load-aware selection ============
+ baseCandidateCount := 0
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
@@ -1504,9 +1641,10 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
- if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
+ if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
continue
}
+ baseCandidateCount++
candidates = append(candidates, acc)
}
@@ -1526,12 +1664,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if err != nil {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, false)
+ if requireCompact {
+ ordered = prioritizeOpenAICompactAccounts(ordered)
+ }
for _, acc := range ordered {
- fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
+ fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
if fresh == nil {
continue
}
- if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
+ fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
+ if fresh == nil {
+ continue
+ }
+ if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
@@ -1579,12 +1724,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
})
shuffleWithinSortGroups(available)
- for _, item := range available {
- fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel)
+ selectionOrder := make([]accountWithLoad, 0, len(available))
+ if requireCompact {
+ appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
+ for _, item := range available {
+ if openAICompactSupportTier(item.account) == tier {
+ out = append(out, item)
+ }
+ }
+ return out
+ }
+ selectionOrder = appendTier(selectionOrder, 2)
+ selectionOrder = appendTier(selectionOrder, 1)
+ // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
+ // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
+ selectionOrder = appendTier(selectionOrder, 0)
+ } else {
+ selectionOrder = append(selectionOrder, available...)
+ }
+
+ for _, item := range selectionOrder {
+ fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
if fresh == nil {
continue
}
- if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
+ fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
+ if fresh == nil {
+ continue
+ }
+ if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
@@ -1600,12 +1768,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
// ============ Layer 3: Fallback wait ============
sortAccountsByPriorityAndLastUsed(candidates, false)
+ if requireCompact {
+ candidates = prioritizeOpenAICompactAccounts(candidates)
+ }
for _, acc := range candidates {
- fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel)
+ fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
if fresh == nil {
continue
}
- if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
+ fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
+ if fresh == nil {
+ continue
+ }
+ if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{
@@ -1616,6 +1791,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
})
}
+ if requireCompact && baseCandidateCount > 0 {
+ return nil, ErrNoAvailableCompactAccounts
+ }
return nil, ErrNoAvailableAccounts
}
@@ -1646,7 +1824,7 @@ func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accoun
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
-func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string) *Account {
+func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
if account == nil {
return nil
}
@@ -1660,20 +1838,20 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
fresh = current
}
- if !fresh.IsSchedulable() || !fresh.IsOpenAI() {
- return nil
- }
- if requestedModel != "" && !fresh.IsModelSupported(requestedModel) {
+ if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
return nil
}
return fresh
}
-func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string) *Account {
+func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account {
if account == nil {
return nil
}
if s.schedulerSnapshot == nil || s.accountRepo == nil {
+ if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) {
+ return nil
+ }
return account
}
@@ -1681,10 +1859,7 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if err != nil || latest == nil {
return nil
}
- if !latest.IsSchedulable() || !latest.IsOpenAI() {
- return nil
- }
- if requestedModel != "" && !latest.IsModelSupported(requestedModel) {
+ if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
return nil
}
return latest
@@ -1933,6 +2108,23 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("instructions", "You are a helpful coding assistant.")
}
+ if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) {
+ bodyModified = true
+ disablePatch()
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
+ }
+
+ if normalizeOpenAIResponsesImageGenerationTools(reqBody) {
+ bodyModified = true
+ disablePatch()
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
+ }
+ if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) {
+ bodyModified = true
+ disablePatch()
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
+ }
+
// 对所有请求执行模型映射(包含 Codex CLI)。
billingModel := account.GetMappedModel(reqModel)
if billingModel != reqModel {
@@ -1942,18 +2134,81 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("model", billingModel)
}
upstreamModel := billingModel
+ if normalizeOpenAIResponsesImageOnlyModel(reqBody) {
+ bodyModified = true
+ disablePatch()
+ if model, ok := reqBody["model"].(string); ok {
+ upstreamModel = strings.TrimSpace(model)
+ }
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[OpenAI] Normalized /responses image-only model request inbound_model=%s image_model=%s upstream_model=%s",
+ reqModel,
+ billingModel,
+ upstreamModel,
+ )
+ }
+ if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil {
+ setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": err.Error(),
+ "param": "model",
+ },
+ })
+ return nil, err
+ }
+ if hasOpenAIImageGenerationTool(reqBody) {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s",
+ reqModel,
+ upstreamModel,
+ account.Type,
+ )
+ }
+ if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil {
+ setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "")
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": err.Error(),
+ "param": "input",
+ },
+ })
+ return nil, err
+ }
+
+ // Compact-only model 映射:仅在 /responses/compact 路径生效,且优先级高于
+ // OAuth 模型规范化(避免 OAuth 规范化覆盖 compact-only 自定义模型)。
+ isCompactRequest := isOpenAIResponsesCompactPath(c)
+ compactMapped := false
+ if isCompactRequest {
+ compactMappedModel := resolveOpenAICompactForwardModel(account, billingModel)
+ if compactMappedModel != "" && compactMappedModel != billingModel {
+ compactMapped = true
+ upstreamModel = compactMappedModel
+ reqBody["model"] = compactMappedModel
+ bodyModified = true
+ markPatchSet("model", compactMappedModel)
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Compact model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", billingModel, compactMappedModel, account.Name, isCodexCLI)
+ }
+ }
// OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为
// 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名,
// 以兼容自定义 base_url 的 OpenAI-compatible 上游。
if model, ok := reqBody["model"].(string); ok {
- upstreamModel = normalizeOpenAIModelForUpstream(account, model)
- if upstreamModel != "" && upstreamModel != model {
- logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
- model, upstreamModel, account.Name, account.Type, isCodexCLI)
- reqBody["model"] = upstreamModel
- bodyModified = true
- markPatchSet("model", upstreamModel)
+ if !compactMapped {
+ upstreamModel = normalizeOpenAIModelForUpstream(account, model)
+ if upstreamModel != "" && upstreamModel != model {
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)",
+ model, upstreamModel, account.Name, account.Type, isCodexCLI)
+ reqBody["model"] = upstreamModel
+ bodyModified = true
+ markPatchSet("model", upstreamModel)
+ }
}
// 移除 gpt-5.2-codex 以下的版本 verbosity 参数
@@ -1976,7 +2231,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
if account.Type == AccountTypeOAuth {
- codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isOpenAIResponsesCompactPath(c))
+ codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest)
if codexResult.Modified {
bodyModified = true
disablePatch()
@@ -2058,6 +2313,48 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
}
+ // Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤):
+ // 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略
+ // 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽
+ // fast 时在此生效。
+ //
+ // 注意:
+ // 1. 此处统一使用 upstreamModel(已经过 GetMappedModel +
+ // normalizeOpenAIModelForUpstream + Codex OAuth normalize),与
+ // chat-completions / messages 入口保持一致,避免不同入口因为模型
+ // 维度不同而出现 whitelist 命中差异。
+ // 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body,
+ // 否则 native /responses 入口透传 "fast" 给上游会被拒。chat-
+ // completions 入口由 normalizeResponsesBodyServiceTier 完成同一
+ // 行为,这里手工实现等效逻辑。
+ if rawTier, ok := reqBody["service_tier"].(string); ok {
+ if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" {
+ action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier)
+ switch action {
+ case BetaPolicyActionBlock:
+ msg := errMsg
+ if msg == "" {
+ msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel)
+ }
+ blocked := &OpenAIFastBlockedError{Message: msg}
+ writeOpenAIFastPolicyBlockedResponse(c, blocked)
+ return nil, blocked
+ case BetaPolicyActionFilter:
+ delete(reqBody, "service_tier")
+ bodyModified = true
+ disablePatch()
+ default:
+ // pass:若客户端传的是别名 "fast",归一化为 "priority"
+ // 后写回 body,确保上游收到的是其能识别的规范值。
+ if normTier != rawTier {
+ reqBody["service_tier"] = normTier
+ bodyModified = true
+ markPatchSet("service_tier", normTier)
+ }
+ }
+ }
+ }
+
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@@ -2451,6 +2748,19 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
reqStream bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
+ upstreamPassthroughModel := ""
+ if isOpenAIResponsesCompactPath(c) {
+ compactMappedModel := resolveOpenAICompactForwardModel(account, reqModel)
+ if compactMappedModel != "" && compactMappedModel != reqModel {
+ nextBody, setErr := sjson.SetBytes(body, "model", compactMappedModel)
+ if setErr != nil {
+ return nil, fmt.Errorf("set compact passthrough model: %w", setErr)
+ }
+ body = nextBody
+ upstreamPassthroughModel = compactMappedModel
+ }
+ }
+
if account != nil && account.Type == AccountTypeOAuth {
if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" {
rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field"
@@ -2493,6 +2803,26 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
body = sanitizedBody
}
+ // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier).
+ // 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 +
+ // OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。
+ // 这样可以与 chat-completions / messages / native /responses 入口的
+ // upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有
+ // model 字段时退回 reqModel。
+ policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
+ if policyModel == "" {
+ policyModel = reqModel
+ }
+ updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body)
+ if policyErr != nil {
+ var blocked *OpenAIFastBlockedError
+ if errors.As(policyErr, &blocked) {
+ writeOpenAIFastPolicyBlockedResponse(c, blocked)
+ }
+ return nil, policyErr
+ }
+ body = updatedBody
+
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@@ -2576,14 +2906,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
- result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime)
+ result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
if err != nil {
return nil, err
}
usage = result.usage
firstTokenMs = result.firstTokenMs
} else {
- usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c)
+ usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
if err != nil {
return nil, err
}
@@ -2601,6 +2931,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: reqModel,
+ UpstreamModel: upstreamPassthroughModel,
ServiceTier: extractOpenAIServiceTierFromBody(body),
ReasoningEffort: reasoningEffort,
Stream: reqStream,
@@ -2904,12 +3235,121 @@ type openaiStreamingResultPassthrough struct {
firstTokenMs *int
}
+func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
+ if localStarted {
+ return true
+ }
+ return c != nil && c.Writer != nil && c.Writer.Written()
+}
+
+func openAIStreamEventIsPreamble(eventType string) bool {
+ switch strings.TrimSpace(eventType) {
+ case "response.created", "response.in_progress":
+ return true
+ default:
+ return false
+ }
+}
+
+func openAIStreamDataStartsClientOutput(data, eventType string) bool {
+ trimmed := strings.TrimSpace(data)
+ if trimmed == "" {
+ return false
+ }
+ if strings.TrimSpace(eventType) == "response.failed" {
+ return false
+ }
+ return !openAIStreamEventIsPreamble(eventType)
+}
+
+func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
+ code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String()))
+ if code == "" {
+ code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String()))
+ }
+ errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String()))
+ if errType == "" {
+ errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String()))
+ }
+ combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
+ if combined == "" {
+ return true
+ }
+ nonRetryableMarkers := []string{
+ "invalid_request",
+ "content_policy",
+ "policy",
+ "safety",
+ "high-risk cyber",
+ "not allowed",
+ "violat",
+ }
+ for _, marker := range nonRetryableMarkers {
+ if strings.Contains(combined, marker) {
+ return false
+ }
+ }
+ return true
+}
+
+func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
+ c *gin.Context,
+ account *Account,
+ passthrough bool,
+ upstreamRequestID string,
+ payload []byte,
+ message string,
+) *UpstreamFailoverError {
+ message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
+ if message == "" {
+ message = "OpenAI stream disconnected before completion"
+ }
+ detail := ""
+ if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
+ maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
+ if maxBytes <= 0 {
+ maxBytes = 2048
+ }
+ detail = truncateString(string(payload), maxBytes)
+ }
+ if c != nil {
+ setOpsUpstreamError(c, http.StatusBadGateway, message, detail)
+ event := OpsUpstreamErrorEvent{
+ Platform: PlatformOpenAI,
+ UpstreamStatusCode: http.StatusBadGateway,
+ UpstreamRequestID: strings.TrimSpace(upstreamRequestID),
+ Passthrough: passthrough,
+ Kind: "failover",
+ Message: message,
+ Detail: detail,
+ }
+ if account != nil {
+ event.Platform = account.Platform
+ event.AccountID = account.ID
+ event.AccountName = account.Name
+ }
+ appendOpsUpstreamError(c, event)
+ }
+ body, _ := json.Marshal(gin.H{
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": message,
+ },
+ })
+ return &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ ResponseBody: body,
+ }
+}
+
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
+ originalModel string,
+ mappedModel string,
) (*openaiStreamingResultPassthrough, error) {
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
@@ -2933,7 +3373,22 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
+ sawFailedEvent := false
+ failedMessage := ""
+ clientOutputStarted := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
+ pendingLines := make([]string, 0, 8)
+ writePendingLines := func() bool {
+ for _, pending := range pendingLines {
+ if _, err := fmt.Fprintln(w, pending); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
+ return false
+ }
+ }
+ pendingLines = pendingLines[:0]
+ return true
+ }
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -2944,18 +3399,40 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
scanner.Buffer(scanBuf[:0], maxLineSize)
defer putSSEScannerBuf64K(scanBuf)
+ needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
+
for scanner.Scan() {
line := scanner.Text()
+ lineStartsClientOutput := false
+ forceFlushFailedEvent := false
if data, ok := extractOpenAISSEDataLine(line); ok {
dataBytes := []byte(data)
trimmedData := strings.TrimSpace(data)
+ if needModelReplace && strings.Contains(data, mappedModel) {
+ line = s.replaceModelInSSELine(line, mappedModel, originalModel)
+ if replacedData, replaced := extractOpenAISSEDataLine(line); replaced {
+ dataBytes = []byte(replacedData)
+ trimmedData = strings.TrimSpace(replacedData)
+ }
+ }
+ eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
+ if eventType == "response.failed" {
+ failedMessage = extractOpenAISSEErrorMessage(dataBytes)
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
+ }
+ forceFlushFailedEvent = true
+ sawFailedEvent = true
+ }
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
- if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
+ lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
+ if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -2963,20 +3440,30 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if !clientDisconnected {
+ if !clientOutputStarted && !lineStartsClientOutput {
+ pendingLines = append(pendingLines, line)
+ continue
+ }
+ if !clientOutputStarted && len(pendingLines) > 0 {
+ if !writePendingLines() {
+ continue
+ }
+ }
if _, err := fmt.Fprintln(w, line); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
} else {
+ clientOutputStarted = true
flusher.Flush()
}
}
}
if err := scanner.Err(); err != nil {
- if sawTerminalEvent {
+ if sawTerminalEvent && !sawFailedEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
- if clientDisconnected {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
+ if sawFailedEvent {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
@@ -2985,6 +3472,17 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
}
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ msg := "OpenAI stream disconnected before completion"
+ if errText := strings.TrimSpace(err.Error()); errText != "" {
+ msg += ": " + errText
+ }
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
+ }
+ if clientDisconnected {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
+ }
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
account.ID,
@@ -2993,12 +3491,19 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
+ if sawFailedEvent {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
+ }
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
+ }
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
@@ -3009,6 +3514,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
ctx context.Context,
resp *http.Response,
c *gin.Context,
+ originalModel string,
+ mappedModel string,
) (*OpenAIUsage, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
@@ -3020,7 +3527,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
// stream=false was requested. Without this conversion the client would
// receive raw SSE text or a terminal event with empty output.
if isEventStreamResponse(resp.Header) {
- return s.handlePassthroughSSEToJSON(resp, c, body)
+ return s.handlePassthroughSSEToJSON(resp, c, body, originalModel, mappedModel)
}
usage := &OpenAIUsage{}
@@ -3042,14 +3549,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
if contentType == "" {
contentType = "application/json"
}
+ if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
+ body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
+ }
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
-// response for the passthrough path. It mirrors handleSSEToJSON but skips
-// model replacement (passthrough does not remap models).
-func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte) (*OpenAIUsage, error) {
+// response for the passthrough path. It mirrors handleSSEToJSON while
+// preserving passthrough payloads, except compact-only model remapping may
+// rewrite model fields back to the original requested model.
+func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
@@ -3068,6 +3579,9 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
}
}
body = finalResponse
+ if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
+ body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
+ }
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
@@ -3080,6 +3594,10 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
}
usage = s.parseSSEUsageFromBody(bodyText)
+ if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
+ bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
+ }
+ body = []byte(bodyText)
}
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
@@ -3578,8 +4096,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
- // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
- lastDataAt := time.Now()
+ // Track downstream writes separately from upstream reads: pre-output failover
+ // can buffer response.created / response.in_progress, so keepalive must be
+ // based on downstream idle time.
+ lastDownstreamWriteAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
@@ -3587,6 +4107,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
+ sawFailedEvent := false
+ failedMessage := ""
+ clientOutputStarted := false
+ upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
+ var streamFailoverErr error
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
@@ -3603,7 +4128,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if err := flushBuffered(); err != nil {
clientDisconnected = true
+ return
}
+ clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
needModelReplace := originalModel != mappedModel
@@ -3611,45 +4139,73 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
}
finalizeStream := func() (*openaiStreamingResult, error) {
+ if !sawTerminalEvent {
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ return resultWithUsage(), s.newOpenAIStreamFailoverError(
+ c,
+ account,
+ false,
+ upstreamRequestID,
+ nil,
+ "OpenAI stream ended before a terminal event",
+ )
+ }
+ return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
+ }
+ if sawFailedEvent {
+ return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
+ }
if !clientDisconnected {
+ hadBufferedData := bufferedWriter.Buffered() > 0
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
+ } else if hadBufferedData {
+ clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
}
- if !sawTerminalEvent {
- return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
- }
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
- if sawTerminalEvent {
+ if sawTerminalEvent && !sawFailedEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
+ if sawFailedEvent {
+ return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage), true
+ }
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
- // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
- if clientDisconnected {
- return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
- }
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
sendErrorEvent("response_too_large")
return resultWithUsage(), scanErr, true
}
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
+ msg := "OpenAI stream disconnected before completion"
+ if errText := strings.TrimSpace(scanErr.Error()); errText != "" {
+ msg += ": " + errText
+ }
+ return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, nil, msg), true
+ }
+ // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
+ if clientDisconnected {
+ return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
+ }
sendErrorEvent("stream_read_error")
return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
}
processSSELine := func(line string, queueDrained bool) {
- lastDataAt = time.Now()
-
+ if streamFailoverErr != nil {
+ return
+ }
// Extract data from SSE line (supports both "data: " and "data:" formats)
if data, ok := extractOpenAISSEDataLine(line); ok {
@@ -3663,18 +4219,32 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
+ eventType := strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
+ forceFlushFailedEvent := false
+ if eventType == "response.failed" {
+ failedMessage = extractOpenAISSEErrorMessage(dataBytes)
+ if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
+ sawFailedEvent = true
+ streamFailoverErr = s.newOpenAIStreamFailoverError(c, account, false, upstreamRequestID, dataBytes, failedMessage)
+ return
+ }
+ forceFlushFailedEvent = true
+ sawFailedEvent = true
+ }
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
dataBytes = correctedData
data = string(correctedData)
line = "data: " + data
+ eventType = strings.TrimSpace(gjson.GetBytes(dataBytes, "type").String())
}
+ startsClientOutput := forceFlushFailedEvent || openAIStreamDataStartsClientOutput(data, eventType)
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
- shouldFlush := queueDrained
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ shouldFlush := queueDrained && (clientOutputStarted || startsClientOutput)
+ if firstTokenMs == nil && startsClientOutput {
// 保证首个 token 事件尽快出站,避免影响 TTFT。
shouldFlush = true
}
@@ -3688,12 +4258,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
+ } else {
+ clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
}
}
// Record first token time
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ if firstTokenMs == nil && startsClientOutput {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
@@ -3709,10 +4282,13 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} else if _, err := bufferedWriter.WriteString("\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
- } else if queueDrained {
+ } else if queueDrained && clientOutputStarted {
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
+ } else {
+ clientOutputStarted = true
+ lastDownstreamWriteAt = time.Now()
}
}
}
@@ -3723,6 +4299,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
defer putSSEScannerBuf64K(scanBuf)
for scanner.Scan() {
processSSELine(scanner.Text(), true)
+ if streamFailoverErr != nil {
+ return resultWithUsage(), streamFailoverErr
+ }
}
if result, err, done := handleScanErr(scanner.Err()); done {
return result, err
@@ -3772,6 +4351,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return result, err
}
processSSELine(ev.line, len(events) == 0)
+ if streamFailoverErr != nil {
+ return resultWithUsage(), streamFailoverErr
+ }
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
@@ -3793,7 +4375,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if clientDisconnected {
continue
}
- if time.Since(lastDataAt) < keepaliveInterval {
+ if time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
}
if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
@@ -3804,6 +4386,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := flushBuffered(); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
+ } else {
+ lastDownstreamWriteAt = time.Now()
}
}
}
@@ -3882,13 +4466,15 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
return
}
eventType := gjson.GetBytes(data, "type").String()
- if eventType != "response.completed" && eventType != "response.done" {
+ if eventType != "response.completed" && eventType != "response.done" &&
+ eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" {
return
}
usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
+ usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int())
}
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
@@ -3900,11 +4486,13 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
"usage.input_tokens",
"usage.output_tokens",
"usage.input_tokens_details.cached_tokens",
+ "usage.output_tokens_details.image_tokens",
)
return OpenAIUsage{
InputTokens: int(values[0].Int()),
OutputTokens: int(values[1].Int()),
CacheReadInputTokens: int(values[2].Int()),
+ ImageOutputTokens: int(values[3].Int()),
}, true
}
@@ -4026,7 +4614,7 @@ func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
}
eventType := strings.TrimSpace(gjson.Get(data, "type").String())
switch eventType {
- case "response.completed", "response.done", "response.failed":
+ case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
return eventType, []byte(data), true
}
}
@@ -4087,22 +4675,39 @@ func extractCodexFinalResponse(body string) ([]byte, bool) {
// Returns (nil, false) if no content was found in deltas.
func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
acc := apicompat.NewBufferedResponseAccumulator()
+ imageOutputs := make([]json.RawMessage, 0, 1)
+ seenImages := make(map[string]struct{})
lines := strings.Split(bodyText, "\n")
for _, line := range lines {
data, ok := extractOpenAISSEDataLine(line)
if !ok || data == "" || data == "[DONE]" {
continue
}
+ if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok {
+ imageOutputs = append(imageOutputs, imageOutput)
+ }
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(data), &event); err != nil {
continue
}
acc.ProcessEvent(&event)
}
- if !acc.HasContent() {
+ if !acc.HasContent() && len(imageOutputs) == 0 {
return nil, false
}
- output := acc.BuildOutput()
+
+ var output []json.RawMessage
+ if acc.HasContent() {
+ outputJSON, err := json.Marshal(acc.BuildOutput())
+ if err == nil {
+ _ = json.Unmarshal(outputJSON, &output)
+ }
+ }
+ output = append(output, imageOutputs...)
+ if len(output) == 0 {
+ return nil, false
+ }
+
outputJSON, err := json.Marshal(output)
if err != nil {
return nil, false
@@ -4110,6 +4715,33 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
return outputJSON, true
}
+func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct{}) (json.RawMessage, bool) {
+ if len(data) == 0 || !gjson.ValidBytes(data) {
+ return nil, false
+ }
+ if gjson.GetBytes(data, "type").String() != "response.output_item.done" {
+ return nil, false
+ }
+ item := gjson.GetBytes(data, "item")
+ if !item.Exists() || !item.IsObject() || item.Get("type").String() != "image_generation_call" {
+ return nil, false
+ }
+ if strings.TrimSpace(item.Get("result").String()) == "" {
+ return nil, false
+ }
+ key := strings.TrimSpace(item.Get("id").String())
+ if key == "" {
+ key = strings.TrimSpace(item.Get("output_format").String()) + "|" + strings.TrimSpace(item.Get("result").String())
+ }
+ if key != "" && seen != nil {
+ if _, exists := seen[key]; exists {
+ return nil, false
+ }
+ seen[key] = struct{}{}
+ }
+ return json.RawMessage(item.Raw), true
+}
+
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
@@ -4297,7 +4929,18 @@ func normalizeOpenAICompactRequestBody(body []byte) ([]byte, bool, error) {
}
normalized := []byte(`{}`)
- for _, field := range []string{"model", "input", "instructions", "previous_response_id"} {
+ // Keep the current Codex /compact schema while still dropping request-scoped
+ // fields such as prompt_cache_key, store, and stream.
+ for _, field := range []string{
+ "model",
+ "input",
+ "instructions",
+ "tools",
+ "parallel_tool_calls",
+ "reasoning",
+ "text",
+ "previous_response_id",
+ } {
value := gjson.GetBytes(body, field)
if !value.Exists() {
continue
@@ -4394,10 +5037,14 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
+ if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
+ s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
+ }
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&
- result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 {
+ result.Usage.CacheCreationInputTokens == 0 && result.Usage.CacheReadInputTokens == 0 &&
+ result.Usage.ImageOutputTokens == 0 && result.ImageCount == 0 {
return nil
}
@@ -4451,21 +5098,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
- if s.resolver != nil && apiKey.Group != nil {
- gid := apiKey.Group.ID
- cost, err = s.billingService.CalculateCostUnified(CostInput{
- Ctx: ctx,
- Model: billingModel,
- GroupID: &gid,
- Tokens: tokens,
- RequestCount: 1,
- RateMultiplier: multiplier,
- ServiceTier: serviceTier,
- Resolver: s.resolver,
- })
- } else {
- cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
- }
+ cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
@@ -4505,6 +5138,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
+ ImageCount: result.ImageCount,
+ ImageSize: optionalTrimmedStringPtr(result.ImageSize),
}
if cost != nil {
usageLog.InputCost = cost.InputCost
@@ -4530,6 +5165,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
+ } else if result.ImageCount > 0 {
+ billingMode := string(BillingModeImage)
+ usageLog.BillingMode = &billingMode
} else {
billingMode := string(BillingModeToken)
usageLog.BillingMode = &billingMode
@@ -4589,6 +5227,83 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
return nil
}
+func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
+ ctx context.Context,
+ result *OpenAIForwardResult,
+ apiKey *APIKey,
+ billingModel string,
+ multiplier float64,
+ tokens UsageTokens,
+ serviceTier string,
+) (*CostBreakdown, error) {
+ if result != nil && result.ImageCount > 0 {
+ return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
+ }
+ if s.resolver != nil && apiKey.Group != nil {
+ gid := apiKey.Group.ID
+ return s.billingService.CalculateCostUnified(CostInput{
+ Ctx: ctx,
+ Model: billingModel,
+ GroupID: &gid,
+ Tokens: tokens,
+ RequestCount: 1,
+ RateMultiplier: multiplier,
+ ServiceTier: serviceTier,
+ Resolver: s.resolver,
+ })
+ }
+ return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
+}
+
+func (s *OpenAIGatewayService) calculateOpenAIImageCost(
+ ctx context.Context,
+ billingModel string,
+ apiKey *APIKey,
+ result *OpenAIForwardResult,
+ multiplier float64,
+) *CostBreakdown {
+ if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
+ (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
+ gid := apiKey.Group.ID
+ cost, err := s.billingService.CalculateCostUnified(CostInput{
+ Ctx: ctx,
+ Model: billingModel,
+ GroupID: &gid,
+ RequestCount: 1,
+ SizeTier: result.ImageSize,
+ RateMultiplier: multiplier,
+ Resolver: s.resolver,
+ Resolved: resolved,
+ })
+ if err == nil {
+ return cost
+ }
+ logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err)
+ }
+
+ var groupConfig *ImagePriceConfig
+ if apiKey != nil && apiKey.Group != nil {
+ groupConfig = &ImagePriceConfig{
+ Price1K: apiKey.Group.ImagePrice1K,
+ Price2K: apiKey.Group.ImagePrice2K,
+ Price4K: apiKey.Group.ImagePrice4K,
+ }
+ }
+ return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
+}
+
+func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
+ if s.resolver == nil || apiKey == nil || apiKey.Group == nil {
+ return nil
+ }
+ gid := apiKey.Group.ID
+ resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
+ if resolved.Source == PricingSourceChannel {
+ return resolved
+ }
+ return nil
+}
+
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
@@ -4838,7 +5553,8 @@ func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, p
}
// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为:
-// 1) store=false 2) 非 compact 保持 stream=true;compact 强制 stream=false
+// 1) 删除 ChatGPT internal API 不支持的顶层 Responses 参数
+// 2) store=false 3) 非 compact 保持 stream=true;compact 强制 stream=false
func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, bool, error) {
if len(body) == 0 {
return body, false, nil
@@ -4847,6 +5563,18 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte, compact bool) ([]byte, boo
normalized := body
changed := false
+ for _, field := range openAIChatGPTInternalUnsupportedFields {
+ if value := gjson.GetBytes(normalized, field); !value.Exists() {
+ continue
+ }
+ next, err := sjson.DeleteBytes(normalized, field)
+ if err != nil {
+ return body, false, fmt.Errorf("normalize passthrough body delete %s: %w", field, err)
+ }
+ normalized = next
+ changed = true
+ }
+
if compact {
if store := gjson.GetBytes(normalized, "store"); store.Exists() {
next, err := sjson.DeleteBytes(normalized, "store")
@@ -4951,14 +5679,319 @@ func normalizeOpenAIServiceTier(raw string) *string {
if value == "fast" {
value = "priority"
}
+ // 放过 OpenAI 官方文档定义的所有合法 tier 值:priority/flex/auto/default/scale。
+ // 对 Codex 客户端零影响(Codex 只发 priority 或 flex,见 codex-rs/core/src/client.rs),
+ // 但能让直连 OpenAI SDK 的用户透传 auto/default/scale 以便抓包/调试。
+ // 真未知值仍返回 nil,由 normalizeResponsesBodyServiceTier 从 body 中删除。
switch value {
- case "priority", "flex":
+ case "priority", "flex", "auto", "default", "scale":
return &value
default:
return nil
}
}
+// OpenAIFastBlockedError indicates a request was rejected by the OpenAI fast
+// policy (action=block). Mirrors BetaBlockedError on the Claude side.
+type OpenAIFastBlockedError struct {
+ Message string
+}
+
+func (e *OpenAIFastBlockedError) Error() string { return e.Message }
+
+// evaluateOpenAIFastPolicy returns the action and error message that should be
+// applied for a request with the given account/model/service_tier. When the
+// policy service is unavailable or no rule matches, it returns
+// (BetaPolicyActionPass, "") so callers can short-circuit safely.
+//
+// Matching rules:
+// - Scope filters by account type (all / oauth / apikey / bedrock)
+// - ServiceTier must be empty (= any), "all", or equal the normalized tier
+// - ModelWhitelist narrows the rule to specific models; FallbackAction
+// handles the non-matching case (default: pass)
+//
+// 与 Claude BetaPolicy 的差异(保留首条匹配 short-circuit):
+// - BetaPolicy 处理的是 anthropic-beta header 中的 token 集合,不同
+// 规则可能针对不同 token,filter 需要累加成 set;block 则 first-match。
+// - OpenAI fast policy 操作的是单个字段 service_tier:filter 即删字段,
+// 没有可累加的对象。一次请求只携带一个 service_tier,规则的 tier
+// 维度天然互斥;同一 (scope, tier) 下若多条规则的 model whitelist
+// 发生重叠,admin 可通过规则顺序明确意图。因此采用 first-match 而
+// 非 BetaPolicy 那样的"block 覆盖 filter 覆盖 pass"语义。
+func (s *OpenAIGatewayService) evaluateOpenAIFastPolicy(ctx context.Context, account *Account, model, serviceTier string) (action, errMsg string) {
+ if s == nil || s.settingService == nil {
+ return BetaPolicyActionPass, ""
+ }
+ tier := strings.ToLower(strings.TrimSpace(serviceTier))
+ if tier == "" {
+ return BetaPolicyActionPass, ""
+ }
+ settings := openAIFastPolicySettingsFromContext(ctx)
+ if settings == nil {
+ fetched, err := s.settingService.GetOpenAIFastPolicySettings(ctx)
+ if err != nil || fetched == nil {
+ return BetaPolicyActionPass, ""
+ }
+ settings = fetched
+ }
+ return evaluateOpenAIFastPolicyWithSettings(settings, account, model, tier)
+}
+
+// evaluateOpenAIFastPolicyWithSettings is the pure-function core extracted so
+// long-lived sessions (e.g. WS) can prefetch settings once and avoid hitting
+// the settingService on every frame. See WSSession entry and
+// openAIFastPolicySettingsFromContext for the caching glue.
+func evaluateOpenAIFastPolicyWithSettings(settings *OpenAIFastPolicySettings, account *Account, model, tier string) (action, errMsg string) {
+ if settings == nil {
+ return BetaPolicyActionPass, ""
+ }
+ isOAuth := account != nil && account.IsOAuth()
+ isBedrock := account != nil && account.IsBedrock()
+ for _, rule := range settings.Rules {
+ if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
+ continue
+ }
+ ruleTier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
+ if ruleTier != "" && ruleTier != OpenAIFastTierAny && ruleTier != tier {
+ continue
+ }
+ eff := BetaPolicyRule{
+ Action: rule.Action,
+ ErrorMessage: rule.ErrorMessage,
+ ModelWhitelist: rule.ModelWhitelist,
+ FallbackAction: rule.FallbackAction,
+ FallbackErrorMessage: rule.FallbackErrorMessage,
+ }
+ return resolveRuleAction(eff, model)
+ }
+ return BetaPolicyActionPass, ""
+}
+
+// openAIFastPolicyCtxKey 是 context 中预取的 OpenAIFastPolicySettings 缓存
+// 键,仅用于 WebSocket 长会话内多帧复用同一份策略快照,避免每帧 DB 命中。
+//
+// Trade-off:策略变更不会影响当前 WS session(只影响新 session)。这是
+// 有意为之 —— 对长会话来说,"策略一致性"比"立刻生效"更重要,且 Claude
+// BetaPolicy 的 gin.Context 缓存也是同样取舍。需要 hot-reload 时管理员
+// 可以通过踢断 session 强制刷新。
+type openAIFastPolicyCtxKeyType struct{}
+
+var openAIFastPolicyCtxKey = openAIFastPolicyCtxKeyType{}
+
+// withOpenAIFastPolicyContext 将一份 settings 快照绑定到 context,供该 ctx
+// 衍生 goroutine 中的 evaluateOpenAIFastPolicy 复用。
+func withOpenAIFastPolicyContext(ctx context.Context, settings *OpenAIFastPolicySettings) context.Context {
+ if ctx == nil || settings == nil {
+ return ctx
+ }
+ return context.WithValue(ctx, openAIFastPolicyCtxKey, settings)
+}
+
+func openAIFastPolicySettingsFromContext(ctx context.Context) *OpenAIFastPolicySettings {
+ if ctx == nil {
+ return nil
+ }
+ if v, ok := ctx.Value(openAIFastPolicyCtxKey).(*OpenAIFastPolicySettings); ok {
+ return v
+ }
+ return nil
+}
+
+// applyOpenAIFastPolicyToBody applies the OpenAI fast policy to a raw request
+// body. When action=filter it removes the service_tier field; when
+// action=block it returns (body, *OpenAIFastBlockedError). On pass it
+// normalizes the service_tier value (e.g. client alias "fast" → "priority"),
+// rewriting the body so the upstream receives a slug it recognizes.
+//
+// Rationale for normalize-on-pass: chat-completions / messages 入口在调用本
+// 函数之前已经通过 normalizeResponsesBodyServiceTier 把 service_tier 归一化
+// 到了上游可识别值;passthrough(OpenAI 自动透传) / native /responses 等
+// 入口没有这一前置步骤,pass 路径下若不在此处归一化,"fast" 就会被原样
+// 透传到 OpenAI 上游导致 400/拒绝。把归一化收敛到本函数,所有入口行为一致。
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToBody(ctx context.Context, account *Account, model string, body []byte) ([]byte, error) {
+ if len(body) == 0 {
+ return body, nil
+ }
+ rawTier := gjson.GetBytes(body, "service_tier").String()
+ if rawTier == "" {
+ return body, nil
+ }
+ normTier := normalizedOpenAIServiceTierValue(rawTier)
+ if normTier == "" {
+ return body, nil
+ }
+ action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+ switch action {
+ case BetaPolicyActionBlock:
+ msg := errMsg
+ if msg == "" {
+ msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+ }
+ return body, &OpenAIFastBlockedError{Message: msg}
+ case BetaPolicyActionFilter:
+ trimmed, err := sjson.DeleteBytes(body, "service_tier")
+ if err != nil {
+ return body, fmt.Errorf("strip service_tier from body: %w", err)
+ }
+ return trimmed, nil
+ default:
+ // pass:把别名(如 "fast")写回为规范值("priority")。
+ if normTier == rawTier {
+ return body, nil
+ }
+ updated, err := sjson.SetBytes(body, "service_tier", normTier)
+ if err != nil {
+ return body, fmt.Errorf("normalize service_tier on pass: %w", err)
+ }
+ return updated, nil
+ }
+}
+
+// writeOpenAIFastPolicyBlockedResponse writes a 403 JSON response for a
+// request blocked by the OpenAI fast policy.
+func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlockedError) {
+ if c == nil || err == nil {
+ return
+ }
+ c.JSON(http.StatusForbidden, gin.H{
+ "error": gin.H{
+ "type": "permission_error",
+ "message": err.Message,
+ },
+ })
+}
+
+// applyOpenAIFastPolicyToWSResponseCreate evaluates the OpenAI fast policy
+// against a single client→upstream WebSocket frame whose top-level
+// "type"=="response.create". It mirrors the HTTP-side
+// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
+// WS payload:
+//
+// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
+// - filter: returns a copy with top-level service_tier removed
+// - block: returns (frame, *OpenAIFastBlockedError)
+//
+// Only frames whose "type" field strictly equals "response.create" are
+// inspected/mutated. Any other frame type — including the empty string —
+// passes through untouched. The OpenAI Realtime client-event spec requires
+// "type" to be set, so an empty type is treated as a malformed frame we do
+// not police; the upstream is the source of truth for rejecting it.
+//
+// service_tier lives at the top level of response.create — same as the
+// Responses HTTP body shape (see openai_gateway_chat_completions.go:304 +
+// extractOpenAIServiceTierFromBody at line 5593, and the test fixture at
+// openai_ws_forwarder_ingress_session_test.go:402). We therefore only need
+// to inspect / strip the top-level field; there is no nested form in the
+// schema today.
+//
+// The caller is responsible for choosing the upstream model passed in —
+// this helper does not re-derive it.
+func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
+ ctx context.Context,
+ account *Account,
+ model string,
+ frame []byte,
+) ([]byte, *OpenAIFastBlockedError, error) {
+ if len(frame) == 0 {
+ return frame, nil, nil
+ }
+ if !gjson.ValidBytes(frame) {
+ return frame, nil, nil
+ }
+ frameType := strings.TrimSpace(gjson.GetBytes(frame, "type").String())
+ // Strict match: only response.create is policy-checked. Empty / other
+ // types pass through untouched so we never accidentally strip fields
+ // from response.cancel, conversation.item.create, or any future
+ // client-event the spec adds. The Realtime spec requires "type" on
+ // every client event, so an empty type is malformed input — let the
+ // upstream reject it rather than guessing at our layer.
+ if frameType != "response.create" {
+ return frame, nil, nil
+ }
+ rawTier := gjson.GetBytes(frame, "service_tier").String()
+ if rawTier == "" {
+ return frame, nil, nil
+ }
+ normTier := normalizedOpenAIServiceTierValue(rawTier)
+ if normTier == "" {
+ return frame, nil, nil
+ }
+ action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, model, normTier)
+ switch action {
+ case BetaPolicyActionBlock:
+ msg := errMsg
+ if msg == "" {
+ msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, model)
+ }
+ return frame, &OpenAIFastBlockedError{Message: msg}, nil
+ case BetaPolicyActionFilter:
+ trimmed, err := sjson.DeleteBytes(frame, "service_tier")
+ if err != nil {
+ return frame, nil, fmt.Errorf("strip service_tier from ws frame: %w", err)
+ }
+ return trimmed, nil, nil
+ default:
+ return frame, nil, nil
+ }
+}
+
+// newOpenAIFastPolicyWSEventID returns a Realtime-style event_id for a
+// server-emitted error event. Matches the loose "evt_" convention used
+// by upstream Realtime servers; the exact value is not load-bearing and is
+// only required for client-side log correlation. We reuse the existing
+// google/uuid dependency rather than pulling a new one.
+func newOpenAIFastPolicyWSEventID() string {
+ id, err := uuid.NewRandom()
+ if err != nil {
+ // Extremely unlikely; fall back to a fixed prefix so the field is
+ // still non-empty and the schema stays self-consistent.
+ return "evt_openai_fast_policy"
+ }
+ // Strip dashes so it visually matches "evt_" rather than UUID v4
+ // canonical form, mirroring what real Realtime traces look like.
+ return "evt_" + strings.ReplaceAll(id.String(), "-", "")
+}
+
+// buildOpenAIFastPolicyBlockedWSEvent renders an OpenAI Realtime/Responses
+// style "error" event payload for a request blocked by the OpenAI fast
+// policy. The shape mirrors Realtime error events as observed in upstream
+// traces and per the spec's server "error" event:
+//
+// {
+// "event_id": "evt_",
+// "type": "error",
+// "error": {
+// "type": "invalid_request_error",
+// "code": "policy_violation",
+// "message": "..."
+// }
+// }
+//
+// event_id lets clients correlate the rejection in their logs; "code" gives
+// programmatic clients a stable identifier (HTTP-side equivalent is the
+// 403 permission_error JSON body).
+func buildOpenAIFastPolicyBlockedWSEvent(err *OpenAIFastBlockedError) []byte {
+ if err == nil {
+ return nil
+ }
+ eventID := newOpenAIFastPolicyWSEventID()
+ payload, mErr := json.Marshal(map[string]any{
+ "event_id": eventID,
+ "type": "error",
+ "error": map[string]any{
+ "type": "invalid_request_error",
+ "code": "policy_violation",
+ "message": err.Message,
+ },
+ })
+ if mErr != nil {
+ // Fallback to a minimal hand-rolled payload; Marshal of the literal
+ // shape above should never fail in practice.
+ return []byte(`{"event_id":"` + eventID + `","type":"error","error":{"type":"invalid_request_error","code":"policy_violation","message":"openai fast policy blocked this request"}}`)
+ }
+ return payload
+}
+
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index cf2d875f..b55f0d2c 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -18,6 +18,7 @@ import (
"github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
// 编译期接口断言
@@ -92,6 +93,13 @@ type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
+type errReadCloser struct {
+ err error
+}
+
+func (r errReadCloser) Read([]byte) (int, error) { return 0, r.err }
+func (r errReadCloser) Close() error { return nil }
+
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
@@ -219,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
}
+func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ svc := &OpenAIGatewayService{}
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
+
+ t.Run("stateless image body stays unstuck", func(t *testing.T) {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
+
+ require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
+ require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
+ })
+
+ t.Run("prompt_cache_key is explicit", func(t *testing.T) {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
+
+ got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
+ require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
+ require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
+ })
+
+ t.Run("header overrides body", func(t *testing.T) {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
+ c.Request.Header.Set("session_id", "header-session")
+
+ got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
+ require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
+ })
+}
+
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -1002,6 +1045,190 @@ func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErr
}
}
+func TestOpenAIStreamingReadErrorBeforeOutputReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: errReadCloser{err: io.ErrUnexpectedEOF},
+ Header: http.Header{"X-Request-Id": []string{"rid-disconnect"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
+func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.in_progress",
+ `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"message":"An error occurred while processing your request."}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
+func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.in_progress",
+ `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-missing-terminal"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
+func TestOpenAIStreamingPreambleKeepaliveUsesDownstreamIdle(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 1,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ go func() {
+ defer func() { _ = pw.Close() }()
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
+ for i := 0; i < 6; i++ {
+ time.Sleep(250 * time.Millisecond)
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_1\"}}\n\n"))
+ }
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}}\n\n"))
+ }()
+
+ result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ _ = pr.Close()
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), ":\n\n")
+ require.Contains(t, rec.Body.String(), "response.completed")
+}
+
+func TestOpenAIStreamingPolicyResponseFailedBeforeOutputPassesThrough(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"type":"safety_error","message":"This request has been flagged for potentially high-risk cyber activity."}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-policy-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.False(t, errors.As(err, &failoverErr))
+ require.True(t, c.Writer.Written())
+ require.Contains(t, rec.Body.String(), "response.failed")
+ require.Contains(t, rec.Body.String(), "high-risk cyber activity")
+}
+
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1071,7 +1298,7 @@ func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T)
go func() {
defer func() { _ = pw.Close() }()
- _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
@@ -1103,16 +1330,52 @@ func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t
go func() {
defer func() { _ = pw.Close() }()
- _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"message\"},\"output_index\":0}\n\n"))
}()
- _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
+ _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
+func TestOpenAIStreamingPassthroughResponseFailedBeforeOutputReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"message":"upstream processing failed"}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-passthrough-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "", "")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "upstream processing failed")
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1138,7 +1401,42 @@ func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
- result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
+ result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
+ _ = pr.Close()
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.usage)
+ require.Equal(t, 2, result.usage.InputTokens)
+ require.Equal(t, 3, result.usage.OutputTokens)
+ require.Equal(t, 1, result.usage.CacheReadInputTokens)
+}
+
+func TestOpenAIStreamingPassthroughResponseIncompleteWithoutDoneMarkerStillSucceeds(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ go func() {
+ defer func() { _ = pw.Close() }()
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.incomplete\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
+ }()
+
+ result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "", "")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
@@ -1469,6 +1767,24 @@ func TestOpenAIResponsesRequestPathSuffix(t *testing.T) {
}
}
+func TestNormalizeOpenAICompactRequestBodyPreservesCurrentCodexPayloadFields(t *testing.T) {
+ body := []byte(`{"model":"gpt-5.5","input":[{"type":"message","role":"user","content":"compact me"}],"instructions":"compact-test","tools":[{"type":"function","name":"shell"}],"parallel_tool_calls":true,"reasoning":{"effort":"high"},"text":{"verbosity":"low"},"previous_response_id":"resp_123","store":true,"stream":true,"prompt_cache_key":"cache_123"}`)
+
+ normalized, changed, err := normalizeOpenAICompactRequestBody(body)
+
+ require.NoError(t, err)
+ require.True(t, changed)
+ require.Equal(t, "gpt-5.5", gjson.GetBytes(normalized, "model").String())
+ require.True(t, gjson.GetBytes(normalized, "tools").Exists())
+ require.True(t, gjson.GetBytes(normalized, "parallel_tool_calls").Bool())
+ require.Equal(t, "high", gjson.GetBytes(normalized, "reasoning.effort").String())
+ require.Equal(t, "low", gjson.GetBytes(normalized, "text.verbosity").String())
+ require.Equal(t, "resp_123", gjson.GetBytes(normalized, "previous_response_id").String())
+ require.False(t, gjson.GetBytes(normalized, "store").Exists())
+ require.False(t, gjson.GetBytes(normalized, "stream").Exists())
+ require.False(t, gjson.GetBytes(normalized, "prompt_cache_key").Exists())
+}
+
func TestOpenAIBuildUpstreamRequestOpenAIPassthroughPreservesCompactPath(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@@ -1880,6 +2196,33 @@ func TestHandleSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
require.NotContains(t, rec.Body.String(), "data:")
}
+func TestHandleSSEToJSON_ReconstructsImageGenerationOutputItemDone(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ }
+ body := []byte(strings.Join([]string{
+ `data: {"type":"response.output_item.done","item":{"id":"ig_123","type":"image_generation_call","result":"aGVsbG8=","revised_prompt":"draw a cat","output_format":"png"}}`,
+ `data: {"type":"response.completed","response":{"id":"resp_img","model":"gpt-5.4","output":[],"usage":{"input_tokens":7,"output_tokens":9,"output_tokens_details":{"image_tokens":4}}}}`,
+ `data: [DONE]`,
+ }, "\n"))
+
+ usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-5.4", "gpt-5.4")
+ require.NoError(t, err)
+ require.NotNil(t, usage)
+ require.Equal(t, 4, usage.ImageOutputTokens)
+ require.NotContains(t, rec.Body.String(), "data:")
+ require.Equal(t, "image_generation_call", gjson.Get(rec.Body.String(), "output.0.type").String())
+ require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "output.0.result").String())
+ require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "output.0.revised_prompt").String())
+}
+
func TestHandleSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go
new file mode 100644
index 00000000..4badcb1c
--- /dev/null
+++ b/backend/internal/service/openai_images.go
@@ -0,0 +1,1346 @@
+package service
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/http"
+ "net/textproto"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/gin-gonic/gin"
+ "github.com/imroc/req/v3"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+const (
+ openAIImagesGenerationsEndpoint = "/v1/images/generations"
+ openAIImagesEditsEndpoint = "/v1/images/edits"
+
+ openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations"
+ openAIImagesEditsURL = "https://api.openai.com/v1/images/edits"
+
+ openAIChatGPTStartURL = "https://chatgpt.com/"
+ openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files"
+ openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
+ openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download
+ openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part
+ openAIImagesResponsesMainModel = "gpt-5.4-mini"
+)
+
+type OpenAIImagesCapability string
+
+const (
+ OpenAIImagesCapabilityBasic OpenAIImagesCapability = "images-basic"
+ OpenAIImagesCapabilityNative OpenAIImagesCapability = "images-native"
+)
+
+type OpenAIImagesUpload struct {
+ FieldName string
+ FileName string
+ ContentType string
+ Data []byte
+ Width int
+ Height int
+}
+
+type OpenAIImagesRequest struct {
+ Endpoint string
+ ContentType string
+ Multipart bool
+ Model string
+ ExplicitModel bool
+ Prompt string
+ Stream bool
+ N int
+ Size string
+ ExplicitSize bool
+ SizeTier string
+ ResponseFormat string
+ Quality string
+ Background string
+ OutputFormat string
+ Moderation string
+ InputFidelity string
+ Style string
+ OutputCompression *int
+ PartialImages *int
+ HasMask bool
+ HasNativeOptions bool
+ RequiredCapability OpenAIImagesCapability
+ InputImageURLs []string
+ MaskImageURL string
+ Uploads []OpenAIImagesUpload
+ MaskUpload *OpenAIImagesUpload
+ Body []byte
+ bodyHash string
+}
+
+func (r *OpenAIImagesRequest) IsEdits() bool {
+ return r != nil && r.Endpoint == openAIImagesEditsEndpoint
+}
+
+func (r *OpenAIImagesRequest) StickySessionSeed() string {
+ if r == nil {
+ return ""
+ }
+ parts := []string{
+ "openai-images",
+ strings.TrimSpace(r.Endpoint),
+ strings.TrimSpace(r.Model),
+ strings.TrimSpace(r.Size),
+ strings.TrimSpace(r.Prompt),
+ }
+ seed := strings.Join(parts, "|")
+ if strings.TrimSpace(r.Prompt) == "" && r.bodyHash != "" {
+ seed += "|body=" + r.bodyHash
+ }
+ return seed
+}
+
+func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []byte) (*OpenAIImagesRequest, error) {
+ if c == nil || c.Request == nil {
+ return nil, fmt.Errorf("missing request context")
+ }
+ endpoint := normalizeOpenAIImagesEndpointPath(c.Request.URL.Path)
+ if endpoint == "" {
+ return nil, fmt.Errorf("unsupported images endpoint")
+ }
+
+ contentType := strings.TrimSpace(c.GetHeader("Content-Type"))
+ req := &OpenAIImagesRequest{
+ Endpoint: endpoint,
+ ContentType: contentType,
+ N: 1,
+ Body: body,
+ }
+ if len(body) > 0 {
+ sum := sha256.Sum256(body)
+ req.bodyHash = hex.EncodeToString(sum[:8])
+ }
+
+ mediaType, _, err := mime.ParseMediaType(contentType)
+ if err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
+ req.Multipart = true
+ if parseErr := parseOpenAIImagesMultipartRequest(body, contentType, req); parseErr != nil {
+ return nil, parseErr
+ }
+ } else {
+ if len(body) == 0 {
+ return nil, fmt.Errorf("request body is empty")
+ }
+ if !gjson.ValidBytes(body) {
+ return nil, fmt.Errorf("failed to parse request body")
+ }
+ if parseErr := parseOpenAIImagesJSONRequest(body, req); parseErr != nil {
+ return nil, parseErr
+ }
+ }
+
+ applyOpenAIImagesDefaults(req)
+ if err := validateOpenAIImagesModel(req.Model); err != nil {
+ return nil, err
+ }
+ req.SizeTier = normalizeOpenAIImageSizeTier(req.Size)
+ req.RequiredCapability = classifyOpenAIImagesCapability(req)
+ return req, nil
+}
+
+func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error {
+ if modelResult := gjson.GetBytes(body, "model"); modelResult.Exists() {
+ req.Model = strings.TrimSpace(modelResult.String())
+ req.ExplicitModel = req.Model != ""
+ }
+ req.Prompt = strings.TrimSpace(gjson.GetBytes(body, "prompt").String())
+
+ if streamResult := gjson.GetBytes(body, "stream"); streamResult.Exists() {
+ if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
+ return fmt.Errorf("invalid stream field type")
+ }
+ req.Stream = streamResult.Bool()
+ }
+
+ if nResult := gjson.GetBytes(body, "n"); nResult.Exists() {
+ if nResult.Type != gjson.Number {
+ return fmt.Errorf("invalid n field type")
+ }
+ req.N = int(nResult.Int())
+ if req.N <= 0 {
+ return fmt.Errorf("n must be greater than 0")
+ }
+ }
+
+ if sizeResult := gjson.GetBytes(body, "size"); sizeResult.Exists() {
+ req.Size = strings.TrimSpace(sizeResult.String())
+ req.ExplicitSize = req.Size != ""
+ }
+ req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String()))
+ req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String())
+ req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String())
+ req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String())
+ req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String())
+ req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String())
+ req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String())
+ req.HasMask = gjson.GetBytes(body, "mask").Exists()
+ if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() {
+ if outputCompression.Type != gjson.Number {
+ return fmt.Errorf("invalid output_compression field type")
+ }
+ v := int(outputCompression.Int())
+ req.OutputCompression = &v
+ }
+ if partialImages := gjson.GetBytes(body, "partial_images"); partialImages.Exists() {
+ if partialImages.Type != gjson.Number {
+ return fmt.Errorf("invalid partial_images field type")
+ }
+ v := int(partialImages.Int())
+ req.PartialImages = &v
+ }
+ if req.IsEdits() {
+ images := gjson.GetBytes(body, "images")
+ if images.Exists() {
+ if !images.IsArray() {
+ return fmt.Errorf("invalid images field type")
+ }
+ for _, item := range images.Array() {
+ if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" {
+ req.InputImageURLs = append(req.InputImageURLs, imageURL)
+ continue
+ }
+ if item.Get("file_id").Exists() {
+ return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)")
+ }
+ }
+ }
+ if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" {
+ req.MaskImageURL = maskImageURL
+ req.HasMask = true
+ }
+ if gjson.GetBytes(body, "mask.file_id").Exists() {
+ return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)")
+ }
+ if len(req.InputImageURLs) == 0 {
+ return fmt.Errorf("images[].image_url is required")
+ }
+ }
+ req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool {
+ return gjson.GetBytes(body, path).Exists()
+ })
+ return nil
+}
+
+func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *OpenAIImagesRequest) error {
+ _, params, err := mime.ParseMediaType(contentType)
+ if err != nil {
+ return fmt.Errorf("invalid multipart content-type: %w", err)
+ }
+ boundary := strings.TrimSpace(params["boundary"])
+ if boundary == "" {
+ return fmt.Errorf("multipart boundary is required")
+ }
+
+ reader := multipart.NewReader(bytes.NewReader(body), boundary)
+ for {
+ part, err := reader.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return fmt.Errorf("read multipart body: %w", err)
+ }
+ name := strings.TrimSpace(part.FormName())
+ if name == "" {
+ _ = part.Close()
+ continue
+ }
+
+ data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize))
+ _ = part.Close()
+ if err != nil {
+ return fmt.Errorf("read multipart field %s: %w", name, err)
+ }
+
+ fileName := strings.TrimSpace(part.FileName())
+ if fileName != "" {
+ partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
+ if name == "mask" && len(data) > 0 {
+ req.HasMask = true
+ width, height := parseOpenAIImageDimensions(part.Header)
+ maskUpload := OpenAIImagesUpload{
+ FieldName: name,
+ FileName: fileName,
+ ContentType: partContentType,
+ Data: data,
+ Width: width,
+ Height: height,
+ }
+ req.MaskUpload = &maskUpload
+ }
+ if name == "image" || strings.HasPrefix(name, "image[") {
+ width, height := parseOpenAIImageDimensions(part.Header)
+ req.Uploads = append(req.Uploads, OpenAIImagesUpload{
+ FieldName: name,
+ FileName: fileName,
+ ContentType: partContentType,
+ Data: data,
+ Width: width,
+ Height: height,
+ })
+ }
+ continue
+ }
+
+ value := strings.TrimSpace(string(data))
+ switch name {
+ case "model":
+ req.Model = value
+ req.ExplicitModel = value != ""
+ case "prompt":
+ req.Prompt = value
+ case "size":
+ req.Size = value
+ req.ExplicitSize = value != ""
+ case "response_format":
+ req.ResponseFormat = strings.ToLower(value)
+ case "stream":
+ parsed, err := strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("invalid stream field value")
+ }
+ req.Stream = parsed
+ case "n":
+ n, err := strconv.Atoi(value)
+ if err != nil || n <= 0 {
+ return fmt.Errorf("n must be a positive integer")
+ }
+ req.N = n
+ case "quality":
+ req.Quality = value
+ req.HasNativeOptions = true
+ case "background":
+ req.Background = value
+ req.HasNativeOptions = true
+ case "output_format":
+ req.OutputFormat = value
+ req.HasNativeOptions = true
+ case "moderation":
+ req.Moderation = value
+ req.HasNativeOptions = true
+ case "input_fidelity":
+ req.InputFidelity = value
+ req.HasNativeOptions = true
+ case "style":
+ req.Style = value
+ req.HasNativeOptions = true
+ case "output_compression":
+ n, err := strconv.Atoi(value)
+ if err != nil {
+ return fmt.Errorf("invalid output_compression field value")
+ }
+ req.OutputCompression = &n
+ req.HasNativeOptions = true
+ case "partial_images":
+ n, err := strconv.Atoi(value)
+ if err != nil {
+ return fmt.Errorf("invalid partial_images field value")
+ }
+ req.PartialImages = &n
+ req.HasNativeOptions = true
+ default:
+ if isOpenAINativeImageOption(name) && value != "" {
+ req.HasNativeOptions = true
+ }
+ }
+ }
+
+ if len(req.Uploads) == 0 && req.IsEdits() {
+ return fmt.Errorf("image file is required")
+ }
+ return nil
+}
+
+func parseOpenAIImageDimensions(_ textproto.MIMEHeader) (int, int) {
+ return 0, 0
+}
+
+func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) {
+ if req == nil {
+ return
+ }
+ if req.N <= 0 {
+ req.N = 1
+ }
+ if strings.TrimSpace(req.Model) != "" {
+ req.Model = strings.TrimSpace(req.Model)
+ return
+ }
+ req.Model = "gpt-image-2"
+}
+
+func isOpenAIImageGenerationModel(model string) bool {
+ return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-")
+}
+
+func validateOpenAIImagesModel(model string) error {
+ model = strings.TrimSpace(model)
+ if isOpenAIImageGenerationModel(model) {
+ return nil
+ }
+ if model == "" {
+ return fmt.Errorf("images endpoint requires an image model")
+ }
+ return fmt.Errorf("images endpoint requires an image model, got %q", model)
+}
+
+func normalizeOpenAIImagesEndpointPath(path string) string {
+ trimmed := strings.TrimSpace(path)
+ switch {
+ case strings.Contains(trimmed, "/images/generations"):
+ return openAIImagesGenerationsEndpoint
+ case strings.Contains(trimmed, "/images/edits"):
+ return openAIImagesEditsEndpoint
+ default:
+ return ""
+ }
+}
+
+func classifyOpenAIImagesCapability(req *OpenAIImagesRequest) OpenAIImagesCapability {
+ if req == nil {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.ExplicitModel || req.ExplicitSize {
+ return OpenAIImagesCapabilityNative
+ }
+ model := strings.ToLower(strings.TrimSpace(req.Model))
+ if !strings.HasPrefix(model, "gpt-image-") {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.Stream || req.N != 1 || req.HasMask || req.HasNativeOptions {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.IsEdits() && !req.Multipart {
+ return OpenAIImagesCapabilityNative
+ }
+ if req.ResponseFormat != "" && req.ResponseFormat != "b64_json" {
+ return OpenAIImagesCapabilityNative
+ }
+ return OpenAIImagesCapabilityBasic
+}
+
+func hasOpenAINativeImageOptions(exists func(path string) bool) bool {
+ for _, path := range []string{
+ "background",
+ "quality",
+ "style",
+ "output_format",
+ "output_compression",
+ "moderation",
+ "input_fidelity",
+ "partial_images",
+ } {
+ if exists(path) {
+ return true
+ }
+ }
+ return false
+}
+
+func isOpenAINativeImageOption(name string) bool {
+ switch strings.TrimSpace(strings.ToLower(name)) {
+ case "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images":
+ return true
+ default:
+ return false
+ }
+}
+
+func normalizeOpenAIImageSizeTier(size string) string {
+ switch strings.ToLower(strings.TrimSpace(size)) {
+ case "1024x1024":
+ return "1K"
+ case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto":
+ return "2K"
+ default:
+ return "2K"
+ }
+}
+
+func (s *OpenAIGatewayService) ForwardImages(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ parsed *OpenAIImagesRequest,
+ channelMappedModel string,
+) (*OpenAIForwardResult, error) {
+ if parsed == nil {
+ return nil, fmt.Errorf("parsed images request is required")
+ }
+ switch account.Type {
+ case AccountTypeAPIKey:
+ return s.forwardOpenAIImagesAPIKey(ctx, c, account, body, parsed, channelMappedModel)
+ case AccountTypeOAuth:
+ return s.forwardOpenAIImagesOAuth(ctx, c, account, parsed, channelMappedModel)
+ default:
+ return nil, fmt.Errorf("unsupported account type: %s", account.Type)
+ }
+}
+
+func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ parsed *OpenAIImagesRequest,
+ channelMappedModel string,
+) (*OpenAIForwardResult, error) {
+ startTime := time.Now()
+ requestModel := strings.TrimSpace(parsed.Model)
+ if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
+ requestModel = mapped
+ }
+ if err := validateOpenAIImagesModel(requestModel); err != nil {
+ return nil, err
+ }
+ upstreamModel := account.GetMappedModel(requestModel)
+ if err := validateOpenAIImagesModel(upstreamModel); err != nil {
+ return nil, err
+ }
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s",
+ strings.TrimSpace(parsed.Model),
+ upstreamModel,
+ parsed.Endpoint,
+ account.Type,
+ )
+ forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel)
+ if err != nil {
+ return nil, err
+ }
+ if !parsed.Multipart {
+ setOpsUpstreamRequestBody(c, forwardBody)
+ }
+
+ token, _, err := s.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
+ if err != nil {
+ return nil, err
+ }
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+ upstreamStart := time.Now()
+ resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
+ if err != nil {
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ setOpsUpstreamError(c, 0, safeErr, "")
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: 0,
+ UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
+ Kind: "request_error",
+ Message: safeErr,
+ })
+ return nil, fmt.Errorf("upstream request failed: %s", safeErr)
+ }
+ if resp.StatusCode >= 400 {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+ resp.Body = io.NopCloser(bytes.NewReader(respBody))
+ upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
+ Kind: "failover",
+ Message: upstreamMsg,
+ })
+ s.handleFailoverSideEffects(ctx, resp, account)
+ return nil, &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
+ }
+ }
+ return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ var usage OpenAIUsage
+ imageCount := parsed.N
+ var firstTokenMs *int
+ if parsed.Stream {
+ streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
+ if err != nil {
+ return nil, err
+ }
+ usage = streamUsage
+ imageCount = streamCount
+ firstTokenMs = ttft
+ } else {
+ nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c)
+ if err != nil {
+ return nil, err
+ }
+ usage = nonStreamUsage
+ if nonStreamCount > 0 {
+ imageCount = nonStreamCount
+ }
+ }
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: upstreamModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ }, nil
+}
+
+func (s *OpenAIGatewayService) buildOpenAIImagesRequest(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ contentType string,
+ token string,
+ endpoint string,
+) (*http.Request, error) {
+ targetURL := openAIImagesGenerationsURL
+ if endpoint == openAIImagesEditsEndpoint {
+ targetURL = openAIImagesEditsURL
+ }
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL != "" {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, err
+ }
+ targetURL = buildOpenAIImagesURL(validatedURL, endpoint)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+token)
+ for key, values := range c.Request.Header {
+ if !openaiPassthroughAllowedHeaders[strings.ToLower(key)] {
+ continue
+ }
+ for _, value := range values {
+ req.Header.Add(key, value)
+ }
+ }
+ customUA := account.GetOpenAIUserAgent()
+ if customUA != "" {
+ req.Header.Set("User-Agent", customUA)
+ }
+ if strings.TrimSpace(contentType) != "" {
+ req.Header.Set("Content-Type", contentType)
+ }
+ return req, nil
+}
+
+func buildOpenAIImagesURL(base string, endpoint string) string {
+ normalized := strings.TrimRight(strings.TrimSpace(base), "/")
+ relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1")
+ if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) {
+ return normalized
+ }
+ if strings.HasSuffix(normalized, "/v1") {
+ return normalized + relative
+ }
+ return normalized + endpoint
+}
+
+func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return body, contentType, nil
+ }
+ mediaType, _, err := mime.ParseMediaType(contentType)
+ if err == nil && strings.EqualFold(mediaType, "multipart/form-data") {
+ rewrittenBody, rewrittenType, rewriteErr := rewriteOpenAIImagesMultipartModel(body, contentType, model)
+ return rewrittenBody, rewrittenType, rewriteErr
+ }
+ rewritten, err := sjson.SetBytes(body, "model", model)
+ if err != nil {
+ return nil, "", fmt.Errorf("rewrite image request model: %w", err)
+ }
+ return rewritten, contentType, nil
+}
+
+func rewriteOpenAIImagesMultipartModel(body []byte, contentType string, model string) ([]byte, string, error) {
+ _, params, err := mime.ParseMediaType(contentType)
+ if err != nil {
+ return nil, "", fmt.Errorf("parse multipart content-type: %w", err)
+ }
+ boundary := strings.TrimSpace(params["boundary"])
+ if boundary == "" {
+ return nil, "", fmt.Errorf("multipart boundary is required")
+ }
+
+ reader := multipart.NewReader(bytes.NewReader(body), boundary)
+ var buffer bytes.Buffer
+ writer := multipart.NewWriter(&buffer)
+ modelWritten := false
+
+ for {
+ part, err := reader.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, "", fmt.Errorf("read multipart body: %w", err)
+ }
+
+ formName := strings.TrimSpace(part.FormName())
+ partHeader := cloneMultipartHeader(part.Header)
+ target, err := writer.CreatePart(partHeader)
+ if err != nil {
+ _ = part.Close()
+ return nil, "", fmt.Errorf("create multipart part: %w", err)
+ }
+
+ if formName == "model" && part.FileName() == "" {
+ if _, err := target.Write([]byte(model)); err != nil {
+ _ = part.Close()
+ return nil, "", fmt.Errorf("rewrite multipart model: %w", err)
+ }
+ modelWritten = true
+ _ = part.Close()
+ continue
+ }
+ if _, err := io.Copy(target, part); err != nil {
+ _ = part.Close()
+ return nil, "", fmt.Errorf("copy multipart part: %w", err)
+ }
+ _ = part.Close()
+ }
+
+ if !modelWritten {
+ if err := writer.WriteField("model", model); err != nil {
+ return nil, "", fmt.Errorf("append multipart model field: %w", err)
+ }
+ }
+ if err := writer.Close(); err != nil {
+ return nil, "", fmt.Errorf("finalize multipart body: %w", err)
+ }
+ return buffer.Bytes(), writer.FormDataContentType(), nil
+}
+
+func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
+ dst := make(textproto.MIMEHeader, len(src))
+ for key, values := range src {
+ copied := make([]string, len(values))
+ copy(copied, values)
+ dst[key] = copied
+ }
+ return dst
+}
+
+func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) {
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
+ if err != nil {
+ return OpenAIUsage{}, 0, err
+ }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ contentType := "application/json"
+ if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
+ if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
+ contentType = upstreamType
+ }
+ }
+ c.Data(resp.StatusCode, contentType, body)
+
+ usage, _ := extractOpenAIUsageFromJSONBytes(body)
+ return usage, extractOpenAIImageCountFromJSONBytes(body), nil
+}
+
+func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
+ resp *http.Response,
+ c *gin.Context,
+ startTime time.Time,
+) (OpenAIUsage, int, *int, error) {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
+ if contentType == "" {
+ contentType = "text/event-stream"
+ }
+ c.Status(resp.StatusCode)
+ c.Header("Content-Type", contentType)
+
+ flusher, ok := c.Writer.(http.Flusher)
+ if !ok {
+ return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
+ }
+
+ reader := bufio.NewReader(resp.Body)
+ usage := OpenAIUsage{}
+ imageCount := 0
+ var firstTokenMs *int
+
+ for {
+ line, err := reader.ReadBytes('\n')
+ if len(line) > 0 {
+ if firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if _, writeErr := c.Writer.Write(line); writeErr != nil {
+ return OpenAIUsage{}, 0, firstTokenMs, writeErr
+ }
+ flusher.Flush()
+
+ if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok && data != "" && data != "[DONE]" {
+ dataBytes := []byte(data)
+ mergeOpenAIUsage(&usage, dataBytes)
+ if count := extractOpenAIImageCountFromJSONBytes(dataBytes); count > imageCount {
+ imageCount = count
+ }
+ }
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return OpenAIUsage{}, 0, firstTokenMs, err
+ }
+ }
+ return usage, imageCount, firstTokenMs, nil
+}
+
+func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
+ if dst == nil {
+ return
+ }
+ if parsed, ok := extractOpenAIUsageFromJSONBytes(body); ok {
+ if parsed.InputTokens > 0 {
+ dst.InputTokens = parsed.InputTokens
+ }
+ if parsed.OutputTokens > 0 {
+ dst.OutputTokens = parsed.OutputTokens
+ }
+ if parsed.CacheReadInputTokens > 0 {
+ dst.CacheReadInputTokens = parsed.CacheReadInputTokens
+ }
+ if parsed.ImageOutputTokens > 0 {
+ dst.ImageOutputTokens = parsed.ImageOutputTokens
+ }
+ }
+}
+
+func extractOpenAIImageCountFromJSONBytes(body []byte) int {
+ if len(body) == 0 || !gjson.ValidBytes(body) {
+ return 0
+ }
+ data := gjson.GetBytes(body, "data")
+ if data.Exists() && data.IsArray() {
+ return len(data.Array())
+ }
+ return 0
+}
+
+type openAIImagePointerInfo struct {
+ Pointer string
+ DownloadURL string
+ B64JSON string
+ MimeType string
+ Prompt string
+}
+
+func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo {
+ if len(body) == 0 {
+ return nil
+ }
+ prompt := ""
+ for _, path := range []string{
+ "message.metadata.dalle.prompt",
+ "metadata.dalle.prompt",
+ "revised_prompt",
+ } {
+ if value := strings.TrimSpace(gjson.GetBytes(body, path).String()); value != "" {
+ prompt = value
+ break
+ }
+ }
+ matches := openAIImagePointerMatches(body)
+ out := make([]openAIImagePointerInfo, 0, len(matches))
+ for _, pointer := range matches {
+ out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt})
+ }
+ return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt))
+}
+
+func openAIImagePointerMatches(body []byte) []string {
+ raw := string(body)
+ matches := make([]string, 0, 4)
+ for _, prefix := range []string{"file-service://", "sediment://"} {
+ start := 0
+ for {
+ idx := strings.Index(raw[start:], prefix)
+ if idx < 0 {
+ break
+ }
+ idx += start
+ end := idx + len(prefix)
+ for end < len(raw) {
+ ch := raw[end]
+ if ch != '-' && ch != '_' &&
+ (ch < '0' || ch > '9') &&
+ (ch < 'a' || ch > 'z') &&
+ (ch < 'A' || ch > 'Z') {
+ break
+ }
+ end++
+ }
+ matches = append(matches, raw[idx:end])
+ start = end
+ }
+ }
+ return dedupeStrings(matches)
+}
+
+func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []openAIImagePointerInfo) []openAIImagePointerInfo {
+ if len(next) == 0 {
+ return existing
+ }
+ seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next))
+ out := make([]openAIImagePointerInfo, 0, len(existing)+len(next))
+ for _, item := range existing {
+ if key := item.identityKey(); key != "" {
+ seen[key] = item
+ }
+ out = append(out, item)
+ }
+ for _, item := range next {
+ key := item.identityKey()
+ if key == "" {
+ continue
+ }
+ if existingItem, ok := seen[key]; ok {
+ merged := mergeOpenAIImagePointerInfo(existingItem, item)
+ if merged != existingItem {
+ for i := range out {
+ if out[i].identityKey() == key {
+ out[i] = merged
+ break
+ }
+ }
+ seen[key] = merged
+ }
+ continue
+ }
+ seen[key] = item
+ out = append(out, item)
+ }
+ return out
+}
+
+func (i openAIImagePointerInfo) identityKey() string {
+ switch {
+ case strings.TrimSpace(i.Pointer) != "":
+ return "pointer:" + strings.TrimSpace(i.Pointer)
+ case strings.TrimSpace(i.DownloadURL) != "":
+ return "download:" + strings.TrimSpace(i.DownloadURL)
+ case strings.TrimSpace(i.B64JSON) != "":
+ b64 := strings.TrimSpace(i.B64JSON)
+ if len(b64) > 64 {
+ b64 = b64[:64]
+ }
+ return "b64:" + b64
+ default:
+ return ""
+ }
+}
+
+func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo {
+ merged := existing
+ if strings.TrimSpace(merged.Pointer) == "" {
+ merged.Pointer = next.Pointer
+ }
+ if strings.TrimSpace(merged.DownloadURL) == "" {
+ merged.DownloadURL = next.DownloadURL
+ }
+ if strings.TrimSpace(merged.B64JSON) == "" {
+ merged.B64JSON = next.B64JSON
+ }
+ if strings.TrimSpace(merged.MimeType) == "" {
+ merged.MimeType = next.MimeType
+ }
+ if strings.TrimSpace(merged.Prompt) == "" {
+ merged.Prompt = next.Prompt
+ }
+ return merged
+}
+
+func resolveOpenAIImageBytes(
+ ctx context.Context,
+ client *req.Client,
+ headers http.Header,
+ conversationID string,
+ pointer openAIImagePointerInfo,
+) ([]byte, error) {
+ if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" {
+ return base64.StdEncoding.DecodeString(normalized)
+ }
+ if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" {
+ return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
+ }
+ if strings.TrimSpace(pointer.Pointer) == "" {
+ return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data")
+ }
+ downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer)
+ if err != nil {
+ return nil, err
+ }
+ return downloadOpenAIImageBytes(ctx, client, headers, downloadURL)
+}
+
+func normalizeOpenAIImageBase64(raw string) string {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return ""
+ }
+ if strings.HasPrefix(strings.ToLower(raw), "data:") {
+ if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) {
+ raw = raw[idx+1:]
+ }
+ }
+ raw = strings.TrimSpace(raw)
+ raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4)
+ if raw == "" {
+ return ""
+ }
+ if _, err := base64.StdEncoding.DecodeString(raw); err != nil {
+ return ""
+ }
+ return raw
+}
+
+func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo {
+ if len(body) == 0 || !gjson.ValidBytes(body) {
+ return nil
+ }
+ var decoded any
+ if err := json.Unmarshal(body, &decoded); err != nil {
+ return nil
+ }
+ var out []openAIImagePointerInfo
+ walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out)
+ return out
+}
+
+func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) {
+ switch value := node.(type) {
+ case map[string]any:
+ localPrompt := prompt
+ for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} {
+ if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" {
+ localPrompt = strings.TrimSpace(v)
+ break
+ }
+ }
+ item := openAIImagePointerInfo{
+ Prompt: localPrompt,
+ Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]),
+ DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]),
+ B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]),
+ MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]),
+ }
+ switch {
+ case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"),
+ strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"),
+ isLikelyOpenAIImageDownloadURL(item.DownloadURL),
+ normalizeOpenAIImageBase64(item.B64JSON) != "":
+ *out = append(*out, item)
+ }
+ for _, child := range value {
+ walkOpenAIImageInlineAssets(child, localPrompt, out)
+ }
+ case []any:
+ for _, child := range value {
+ walkOpenAIImageInlineAssets(child, prompt, out)
+ }
+ }
+}
+
+func firstNonEmptyString(values ...any) string {
+ for _, value := range values {
+ if s, ok := value.(string); ok && strings.TrimSpace(s) != "" {
+ return strings.TrimSpace(s)
+ }
+ }
+ return ""
+}
+
+func isLikelyOpenAIImageDownloadURL(raw string) bool {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return false
+ }
+ if strings.HasPrefix(strings.ToLower(raw), "data:image/") {
+ return true
+ }
+ if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") {
+ return false
+ }
+ lower := strings.ToLower(raw)
+ return strings.Contains(lower, "/download") ||
+ strings.Contains(lower, ".png") ||
+ strings.Contains(lower, ".jpg") ||
+ strings.Contains(lower, ".jpeg") ||
+ strings.Contains(lower, ".webp")
+}
+
+func fetchOpenAIImageDownloadURL(
+ ctx context.Context,
+ client *req.Client,
+ headers http.Header,
+ conversationID string,
+ pointer string,
+) (string, error) {
+ url := ""
+ allowConversationRetry := false
+ switch {
+ case strings.HasPrefix(pointer, "file-service://"):
+ fileID := strings.TrimPrefix(pointer, "file-service://")
+ url = fmt.Sprintf("%s/%s/download", openAIChatGPTFilesURL, fileID)
+ case strings.HasPrefix(pointer, "sediment://"):
+ attachmentID := strings.TrimPrefix(pointer, "sediment://")
+ url = fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s/attachment/%s/download", conversationID, attachmentID)
+ allowConversationRetry = true
+ default:
+ return "", fmt.Errorf("unsupported image pointer: %s", pointer)
+ }
+
+ var lastErr error
+ for attempt := 0; attempt < 8; attempt++ {
+ var result struct {
+ DownloadURL string `json:"download_url"`
+ }
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeaders(headerToMap(headers)).
+ SetSuccessResult(&result).
+ Get(url)
+ if err != nil {
+ lastErr = err
+ } else if resp.IsSuccessState() && strings.TrimSpace(result.DownloadURL) != "" {
+ return strings.TrimSpace(result.DownloadURL), nil
+ } else {
+ statusErr := newOpenAIImageStatusError(resp, "fetch image download url failed")
+ if !allowConversationRetry || !isOpenAIImageTransientConversationNotFoundError(statusErr) {
+ return "", statusErr
+ }
+ lastErr = statusErr
+ }
+ if attempt == 7 {
+ break
+ }
+ timer := time.NewTimer(750 * time.Millisecond)
+ select {
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return "", ctx.Err()
+ case <-timer.C:
+ }
+ }
+ if lastErr == nil {
+ lastErr = fmt.Errorf("fetch image download url failed")
+ }
+ return "", lastErr
+}
+
+func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers http.Header, downloadURL string) ([]byte, error) {
+ request := client.R().
+ SetContext(ctx).
+ DisableAutoReadResponse()
+
+ if strings.HasPrefix(downloadURL, openAIChatGPTStartURL) {
+ downloadHeaders := cloneHTTPHeader(headers)
+ downloadHeaders.Set("Accept", "image/*,*/*;q=0.8")
+ downloadHeaders.Del("Content-Type")
+ request.SetHeaders(headerToMap(downloadHeaders))
+ } else {
+ userAgent := strings.TrimSpace(headers.Get("User-Agent"))
+ if userAgent == "" {
+ userAgent = openAIImageBackendUserAgent
+ }
+ request.SetHeader("User-Agent", userAgent)
+ }
+
+ resp, err := request.Get(downloadURL)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, newOpenAIImageStatusError(resp, "download image bytes failed")
+ }
+ return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes))
+}
+
+type openAIImageStatusError struct {
+ StatusCode int
+ Message string
+ ResponseBody []byte
+ ResponseHeaders http.Header
+ RequestID string
+ URL string
+}
+
+func (e *openAIImageStatusError) Error() string {
+ if e == nil {
+ return "openai image backend request failed"
+ }
+ if e.Message != "" {
+ return e.Message
+ }
+ if e.StatusCode > 0 {
+ return fmt.Sprintf("openai image backend request failed: status %d", e.StatusCode)
+ }
+ return "openai image backend request failed"
+}
+
+func newOpenAIImageStatusError(resp *req.Response, fallback string) error {
+ if resp == nil {
+ if strings.TrimSpace(fallback) == "" {
+ fallback = "openai image backend request failed"
+ }
+ return fmt.Errorf("%s", fallback)
+ }
+
+ statusCode := resp.StatusCode
+ headers := http.Header(nil)
+ requestID := ""
+ requestURL := ""
+ body := []byte(nil)
+
+ if resp.Response != nil {
+ headers = resp.Header.Clone()
+ requestID = strings.TrimSpace(resp.Header.Get("x-request-id"))
+ if resp.Request != nil && resp.Request.URL != nil {
+ requestURL = resp.Request.URL.String()
+ }
+ if resp.Body != nil {
+ body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+ }
+ }
+
+ message := sanitizeUpstreamErrorMessage(extractUpstreamErrorMessage(body))
+ if message == "" {
+ prefix := strings.TrimSpace(fallback)
+ if prefix == "" {
+ prefix = "openai image backend request failed"
+ }
+ message = fmt.Sprintf("%s: status %d", prefix, statusCode)
+ }
+
+ return &openAIImageStatusError{
+ StatusCode: statusCode,
+ Message: message,
+ ResponseBody: body,
+ ResponseHeaders: headers,
+ RequestID: requestID,
+ URL: requestURL,
+ }
+}
+
+func isOpenAIImageTransientConversationNotFoundError(err error) bool {
+ statusErr, ok := err.(*openAIImageStatusError)
+ if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound {
+ return false
+ }
+ msg := strings.ToLower(strings.TrimSpace(statusErr.Message))
+ if strings.Contains(msg, "conversation_not_found") {
+ return true
+ }
+ if strings.Contains(msg, "conversation") && strings.Contains(msg, "not found") {
+ return true
+ }
+ bodyMsg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(statusErr.ResponseBody)))
+ if strings.Contains(bodyMsg, "conversation_not_found") {
+ return true
+ }
+ return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found")
+}
+
+func cloneHTTPHeader(src http.Header) http.Header {
+ dst := make(http.Header, len(src))
+ for key, values := range src {
+ copied := make([]string, len(values))
+ copy(copied, values)
+ dst[key] = copied
+ }
+ return dst
+}
+
+func headerToMap(header http.Header) map[string]string {
+ if len(header) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(header))
+ for key, values := range header {
+ if len(values) == 0 {
+ continue
+ }
+ result[key] = values[0]
+ }
+ return result
+}
+
+func dedupeStrings(values []string) []string {
+ if len(values) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(values))
+ out := make([]string, 0, len(values))
+ for _, value := range values {
+ if _, ok := seen[value]; ok {
+ continue
+ }
+ seen[value] = struct{}{}
+ out = append(out, value)
+ }
+ return out
+}
diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go
new file mode 100644
index 00000000..64d995e1
--- /dev/null
+++ b/backend/internal/service/openai_images_responses.go
@@ -0,0 +1,853 @@
+package service
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+type openAIResponsesImageResult struct {
+ Result string
+ RevisedPrompt string
+ OutputFormat string
+ Size string
+ Background string
+ Quality string
+ Model string
+}
+
+func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
+ if strings.TrimSpace(result.Result) != "" {
+ return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result)
+ }
+ return "item:" + strings.TrimSpace(itemID)
+}
+
+func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult, seen map[string]struct{}, itemID string, result openAIResponsesImageResult) bool {
+ if results == nil {
+ return false
+ }
+ key := openAIResponsesImageResultKey(itemID, result)
+ if key != "" {
+ if _, exists := seen[key]; exists {
+ return false
+ }
+ seen[key] = struct{}{}
+ }
+ *results = append(*results, result)
+ return true
+}
+
+func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) {
+ if dst == nil {
+ return
+ }
+ if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" {
+ dst.OutputFormat = trimmed
+ }
+ if trimmed := strings.TrimSpace(src.Size); trimmed != "" {
+ dst.Size = trimmed
+ }
+ if trimmed := strings.TrimSpace(src.Background); trimmed != "" {
+ dst.Background = trimmed
+ }
+ if trimmed := strings.TrimSpace(src.Quality); trimmed != "" {
+ dst.Quality = trimmed
+ }
+ if trimmed := strings.TrimSpace(src.Model); trimmed != "" {
+ dst.Model = trimmed
+ }
+}
+
+func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) {
+ switch gjson.GetBytes(payload, "type").String() {
+ case "response.created", "response.in_progress", "response.completed":
+ default:
+ return openAIResponsesImageResult{}, 0, false
+ }
+
+ response := gjson.GetBytes(payload, "response")
+ if !response.Exists() {
+ return openAIResponsesImageResult{}, 0, false
+ }
+
+ meta := openAIResponsesImageResult{
+ OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()),
+ Size: strings.TrimSpace(response.Get("tools.0.size").String()),
+ Background: strings.TrimSpace(response.Get("tools.0.background").String()),
+ Quality: strings.TrimSpace(response.Get("tools.0.quality").String()),
+ Model: strings.TrimSpace(response.Get("tools.0.model").String()),
+ }
+ return meta, response.Get("created_at").Int(), true
+}
+
+func buildOpenAIImagesStreamPartialPayload(
+ eventType string,
+ b64 string,
+ partialImageIndex int64,
+ responseFormat string,
+ createdAt int64,
+ meta openAIResponsesImageResult,
+) []byte {
+ if createdAt <= 0 {
+ createdAt = time.Now().Unix()
+ }
+
+ payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`)
+ payload, _ = sjson.SetBytes(payload, "type", eventType)
+ payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
+ payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex)
+ payload, _ = sjson.SetBytes(payload, "b64_json", b64)
+ if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
+ payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64)
+ }
+ if meta.Background != "" {
+ payload, _ = sjson.SetBytes(payload, "background", meta.Background)
+ }
+ if meta.OutputFormat != "" {
+ payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat)
+ }
+ if meta.Quality != "" {
+ payload, _ = sjson.SetBytes(payload, "quality", meta.Quality)
+ }
+ if meta.Size != "" {
+ payload, _ = sjson.SetBytes(payload, "size", meta.Size)
+ }
+ if meta.Model != "" {
+ payload, _ = sjson.SetBytes(payload, "model", meta.Model)
+ }
+ return payload
+}
+
+func buildOpenAIImagesStreamCompletedPayload(
+ eventType string,
+ img openAIResponsesImageResult,
+ responseFormat string,
+ createdAt int64,
+ usageRaw []byte,
+) []byte {
+ if createdAt <= 0 {
+ createdAt = time.Now().Unix()
+ }
+
+ payload := []byte(`{"type":"","created_at":0,"b64_json":""}`)
+ payload, _ = sjson.SetBytes(payload, "type", eventType)
+ payload, _ = sjson.SetBytes(payload, "created_at", createdAt)
+ payload, _ = sjson.SetBytes(payload, "b64_json", img.Result)
+ if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
+ payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
+ }
+ if img.Background != "" {
+ payload, _ = sjson.SetBytes(payload, "background", img.Background)
+ }
+ if img.OutputFormat != "" {
+ payload, _ = sjson.SetBytes(payload, "output_format", img.OutputFormat)
+ }
+ if img.Quality != "" {
+ payload, _ = sjson.SetBytes(payload, "quality", img.Quality)
+ }
+ if img.Size != "" {
+ payload, _ = sjson.SetBytes(payload, "size", img.Size)
+ }
+ if img.Model != "" {
+ payload, _ = sjson.SetBytes(payload, "model", img.Model)
+ }
+ if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
+ payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw)
+ }
+ return payload
+}
+
+func openAIImageOutputMIMEType(outputFormat string) string {
+ if outputFormat == "" {
+ return "image/png"
+ }
+ if strings.Contains(outputFormat, "/") {
+ return outputFormat
+ }
+ switch strings.ToLower(strings.TrimSpace(outputFormat)) {
+ case "png":
+ return "image/png"
+ case "jpg", "jpeg":
+ return "image/jpeg"
+ case "webp":
+ return "image/webp"
+ default:
+ return "image/png"
+ }
+}
+
+func openAIImageUploadToDataURL(upload OpenAIImagesUpload) (string, error) {
+ if len(upload.Data) == 0 {
+ return "", fmt.Errorf("upload %q is empty", strings.TrimSpace(upload.FileName))
+ }
+ contentType := strings.TrimSpace(upload.ContentType)
+ if contentType == "" {
+ contentType = http.DetectContentType(upload.Data)
+ }
+ return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(upload.Data), nil
+}
+
+func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel string) ([]byte, error) {
+ if parsed == nil {
+ return nil, fmt.Errorf("parsed images request is required")
+ }
+ prompt := strings.TrimSpace(parsed.Prompt)
+ if prompt == "" {
+ return nil, fmt.Errorf("prompt is required")
+ }
+
+ inputImages := make([]string, 0, len(parsed.InputImageURLs)+len(parsed.Uploads))
+ for _, imageURL := range parsed.InputImageURLs {
+ if trimmed := strings.TrimSpace(imageURL); trimmed != "" {
+ inputImages = append(inputImages, trimmed)
+ }
+ }
+ for _, upload := range parsed.Uploads {
+ dataURL, err := openAIImageUploadToDataURL(upload)
+ if err != nil {
+ return nil, err
+ }
+ inputImages = append(inputImages, dataURL)
+ }
+ if parsed.IsEdits() && len(inputImages) == 0 {
+ return nil, fmt.Errorf("image input is required")
+ }
+
+ req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`)
+ req, _ = sjson.SetBytes(req, "model", openAIImagesResponsesMainModel)
+
+ input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`)
+ input, _ = sjson.SetBytes(input, "0.content.0.text", prompt)
+ for index, imageURL := range inputImages {
+ part := []byte(`{"type":"input_image","image_url":""}`)
+ part, _ = sjson.SetBytes(part, "image_url", imageURL)
+ input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", index+1), part)
+ }
+ req, _ = sjson.SetRawBytes(req, "input", input)
+
+ action := "generate"
+ if parsed.IsEdits() {
+ action = "edit"
+ }
+ tool := []byte(`{"type":"image_generation","action":"","model":""}`)
+ tool, _ = sjson.SetBytes(tool, "action", action)
+ tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel))
+
+ for _, field := range []struct {
+ path string
+ value string
+ }{
+ {path: "size", value: parsed.Size},
+ {path: "quality", value: parsed.Quality},
+ {path: "background", value: parsed.Background},
+ {path: "output_format", value: parsed.OutputFormat},
+ {path: "moderation", value: parsed.Moderation},
+ {path: "style", value: parsed.Style},
+ } {
+ if trimmed := strings.TrimSpace(field.value); trimmed != "" {
+ tool, _ = sjson.SetBytes(tool, field.path, trimmed)
+ }
+ }
+ if parsed.OutputCompression != nil {
+ tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression)
+ }
+ if parsed.PartialImages != nil {
+ tool, _ = sjson.SetBytes(tool, "partial_images", *parsed.PartialImages)
+ }
+
+ maskImageURL := strings.TrimSpace(parsed.MaskImageURL)
+ if parsed.MaskUpload != nil {
+ dataURL, err := openAIImageUploadToDataURL(*parsed.MaskUpload)
+ if err != nil {
+ return nil, err
+ }
+ maskImageURL = dataURL
+ }
+ if maskImageURL != "" {
+ tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskImageURL)
+ }
+
+ req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`))
+ req, _ = sjson.SetRawBytes(req, "tools.-1", tool)
+ return req, nil
+}
+
+func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) {
+ if gjson.GetBytes(payload, "type").String() != "response.completed" {
+ return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type")
+ }
+
+ createdAt := gjson.GetBytes(payload, "response.created_at").Int()
+ if createdAt <= 0 {
+ createdAt = time.Now().Unix()
+ }
+
+ var (
+ results []openAIResponsesImageResult
+ firstMeta openAIResponsesImageResult
+ )
+ output := gjson.GetBytes(payload, "response.output")
+ if output.IsArray() {
+ for _, item := range output.Array() {
+ if item.Get("type").String() != "image_generation_call" {
+ continue
+ }
+ result := strings.TrimSpace(item.Get("result").String())
+ if result == "" {
+ continue
+ }
+ entry := openAIResponsesImageResult{
+ Result: result,
+ RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
+ OutputFormat: strings.TrimSpace(item.Get("output_format").String()),
+ Size: strings.TrimSpace(item.Get("size").String()),
+ Background: strings.TrimSpace(item.Get("background").String()),
+ Quality: strings.TrimSpace(item.Get("quality").String()),
+ }
+ if len(results) == 0 {
+ firstMeta = entry
+ }
+ results = append(results, entry)
+ }
+ }
+
+ var usageRaw []byte
+ if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() {
+ usageRaw = []byte(usage.Raw)
+ }
+ return results, createdAt, usageRaw, firstMeta, nil
+}
+
+func extractOpenAIImageFromResponsesOutputItemDone(payload []byte) (openAIResponsesImageResult, string, bool, error) {
+ if gjson.GetBytes(payload, "type").String() != "response.output_item.done" {
+ return openAIResponsesImageResult{}, "", false, fmt.Errorf("unexpected event type")
+ }
+
+ item := gjson.GetBytes(payload, "item")
+ if !item.Exists() || item.Get("type").String() != "image_generation_call" {
+ return openAIResponsesImageResult{}, "", false, nil
+ }
+
+ result := strings.TrimSpace(item.Get("result").String())
+ if result == "" {
+ return openAIResponsesImageResult{}, "", false, nil
+ }
+
+ entry := openAIResponsesImageResult{
+ Result: result,
+ RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
+ OutputFormat: strings.TrimSpace(item.Get("output_format").String()),
+ Size: strings.TrimSpace(item.Get("size").String()),
+ Background: strings.TrimSpace(item.Get("background").String()),
+ Quality: strings.TrimSpace(item.Get("quality").String()),
+ }
+ return entry, strings.TrimSpace(item.Get("id").String()), true, nil
+}
+
+func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, bool, error) {
+ var (
+ fallbackResults []openAIResponsesImageResult
+ fallbackSeen = make(map[string]struct{})
+ createdAt int64
+ usageRaw []byte
+ foundFinal bool
+ responseMeta openAIResponsesImageResult
+ )
+
+ for _, line := range bytes.Split(body, []byte("\n")) {
+ line = bytes.TrimRight(line, "\r")
+ data, ok := extractOpenAISSEDataLine(string(line))
+ if !ok || data == "" || data == "[DONE]" {
+ continue
+ }
+ payload := []byte(data)
+ if !gjson.ValidBytes(payload) {
+ continue
+ }
+ if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
+ mergeOpenAIResponsesImageMeta(&responseMeta, meta)
+ if eventCreatedAt > 0 {
+ createdAt = eventCreatedAt
+ }
+ }
+
+ switch gjson.GetBytes(payload, "type").String() {
+ case "response.output_item.done":
+ result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
+ if err != nil {
+ return nil, 0, nil, openAIResponsesImageResult{}, false, err
+ }
+ if ok {
+ mergeOpenAIResponsesImageMeta(&result, responseMeta)
+ appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result)
+ }
+ case "response.completed":
+ results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
+ if err != nil {
+ return nil, 0, nil, openAIResponsesImageResult{}, false, err
+ }
+ foundFinal = true
+ if completedAt > 0 {
+ createdAt = completedAt
+ }
+ if len(completedUsageRaw) > 0 {
+ usageRaw = completedUsageRaw
+ }
+ if len(results) > 0 {
+ mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
+ return results, createdAt, usageRaw, firstMeta, true, nil
+ }
+ if len(fallbackResults) > 0 {
+ firstMeta = fallbackResults[0]
+ mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
+ return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
+ }
+ }
+ }
+
+ if len(fallbackResults) > 0 {
+ firstMeta := fallbackResults[0]
+ mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
+ return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil
+ }
+ return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
+}
+
+func buildOpenAIImagesAPIResponse(
+ results []openAIResponsesImageResult,
+ createdAt int64,
+ usageRaw []byte,
+ firstMeta openAIResponsesImageResult,
+ responseFormat string,
+) ([]byte, error) {
+ if createdAt <= 0 {
+ createdAt = time.Now().Unix()
+ }
+ out := []byte(`{"created":0,"data":[]}`)
+ out, _ = sjson.SetBytes(out, "created", createdAt)
+
+ format := strings.ToLower(strings.TrimSpace(responseFormat))
+ if format == "" {
+ format = "b64_json"
+ }
+ for _, img := range results {
+ item := []byte(`{}`)
+ if format == "url" {
+ item, _ = sjson.SetBytes(item, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result)
+ } else {
+ item, _ = sjson.SetBytes(item, "b64_json", img.Result)
+ }
+ if img.RevisedPrompt != "" {
+ item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt)
+ }
+ out, _ = sjson.SetRawBytes(out, "data.-1", item)
+ }
+ if firstMeta.Background != "" {
+ out, _ = sjson.SetBytes(out, "background", firstMeta.Background)
+ }
+ if firstMeta.OutputFormat != "" {
+ out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat)
+ }
+ if firstMeta.Quality != "" {
+ out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality)
+ }
+ if firstMeta.Size != "" {
+ out, _ = sjson.SetBytes(out, "size", firstMeta.Size)
+ }
+ if firstMeta.Model != "" {
+ out, _ = sjson.SetBytes(out, "model", firstMeta.Model)
+ }
+ if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) {
+ out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
+ }
+ return out, nil
+}
+
+func openAIImagesStreamPrefix(parsed *OpenAIImagesRequest) string {
+ if parsed != nil && parsed.IsEdits() {
+ return "image_edit"
+ }
+ return "image_generation"
+}
+
+func buildOpenAIImagesStreamErrorBody(message string) []byte {
+ body := []byte(`{"type":"error","error":{"type":"upstream_error","message":""}}`)
+ if strings.TrimSpace(message) == "" {
+ message = "upstream request failed"
+ }
+ body, _ = sjson.SetBytes(body, "error.message", message)
+ return body
+}
+
+func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error {
+ if strings.TrimSpace(eventName) != "" {
+ if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
+ return err
+ }
+ }
+ if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil {
+ return err
+ }
+ flusher.Flush()
+ return nil
+}
+
+func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
+ resp *http.Response,
+ c *gin.Context,
+ responseFormat string,
+ fallbackModel string,
+) (OpenAIUsage, int, error) {
+ body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
+ if err != nil {
+ return OpenAIUsage{}, 0, err
+ }
+
+ var usage OpenAIUsage
+ for _, line := range bytes.Split(body, []byte("\n")) {
+ line = bytes.TrimRight(line, "\r")
+ data, ok := extractOpenAISSEDataLine(string(line))
+ if !ok || data == "" || data == "[DONE]" {
+ continue
+ }
+ dataBytes := []byte(data)
+ s.parseSSEUsageBytes(dataBytes, &usage)
+ }
+ results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
+ if err != nil {
+ return OpenAIUsage{}, 0, err
+ }
+ if len(results) == 0 {
+ return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output")
+ }
+ if strings.TrimSpace(firstMeta.Model) == "" {
+ firstMeta.Model = strings.TrimSpace(fallbackModel)
+ }
+
+ responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat)
+ if err != nil {
+ return OpenAIUsage{}, 0, err
+ }
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody)
+ return usage, len(results), nil
+}
+
+func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
+ resp *http.Response,
+ c *gin.Context,
+ startTime time.Time,
+ responseFormat string,
+ streamPrefix string,
+ fallbackModel string,
+) (OpenAIUsage, int, *int, error) {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+ c.Status(resp.StatusCode)
+
+ flusher, ok := c.Writer.(http.Flusher)
+ if !ok {
+ return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
+ }
+
+ format := strings.ToLower(strings.TrimSpace(responseFormat))
+ if format == "" {
+ format = "b64_json"
+ }
+
+ reader := bufio.NewReader(resp.Body)
+ usage := OpenAIUsage{}
+ imageCount := 0
+ var firstTokenMs *int
+ emitted := make(map[string]struct{})
+ pendingResults := make([]openAIResponsesImageResult, 0, 1)
+ pendingSeen := make(map[string]struct{})
+ streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
+ var createdAt int64
+
+ for {
+ line, err := reader.ReadBytes('\n')
+ if len(line) > 0 {
+ trimmedLine := strings.TrimRight(string(line), "\r\n")
+ data, ok := extractOpenAISSEDataLine(trimmedLine)
+ if ok && data != "" && data != "[DONE]" {
+ if firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ dataBytes := []byte(data)
+ s.parseSSEUsageBytes(dataBytes, &usage)
+ if gjson.ValidBytes(dataBytes) {
+ if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
+ mergeOpenAIResponsesImageMeta(&streamMeta, meta)
+ if eventCreatedAt > 0 {
+ createdAt = eventCreatedAt
+ }
+ }
+ switch gjson.GetBytes(dataBytes, "type").String() {
+ case "response.image_generation_call.partial_image":
+ b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
+ if b64 != "" {
+ eventName := streamPrefix + ".partial_image"
+ partialMeta := streamMeta
+ mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
+ OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
+ Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
+ })
+ payload := buildOpenAIImagesStreamPartialPayload(
+ eventName,
+ b64,
+ gjson.GetBytes(dataBytes, "partial_image_index").Int(),
+ format,
+ createdAt,
+ partialMeta,
+ )
+ if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
+ return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
+ }
+ }
+ case "response.output_item.done":
+ img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
+ if extractErr != nil {
+ _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
+ return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
+ }
+ if !ok {
+ break
+ }
+ mergeOpenAIResponsesImageMeta(&streamMeta, img)
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ key := openAIResponsesImageResultKey(itemID, img)
+ if _, exists := emitted[key]; exists {
+ break
+ }
+ if _, exists := pendingSeen[key]; exists {
+ break
+ }
+ pendingSeen[key] = struct{}{}
+ pendingResults = append(pendingResults, img)
+ case "response.completed":
+ results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
+ if extractErr != nil {
+ _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
+ return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
+ }
+ mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
+ finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
+ finalSeen := make(map[string]struct{})
+ for _, img := range results {
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
+ }
+ for _, img := range pendingResults {
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
+ }
+ if len(finalResults) == 0 {
+ err = fmt.Errorf("upstream did not return image output")
+ _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
+ return OpenAIUsage{}, imageCount, firstTokenMs, err
+ }
+ eventName := streamPrefix + ".completed"
+ for _, img := range finalResults {
+ key := openAIResponsesImageResultKey("", img)
+ if _, exists := emitted[key]; exists {
+ continue
+ }
+ payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
+ if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
+ return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
+ }
+ emitted[key] = struct{}{}
+ }
+ imageCount = len(emitted)
+ return usage, imageCount, firstTokenMs, nil
+ }
+ }
+ }
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
+ return OpenAIUsage{}, imageCount, firstTokenMs, err
+ }
+ }
+
+ if imageCount > 0 {
+ return usage, imageCount, firstTokenMs, nil
+ }
+ if len(pendingResults) > 0 {
+ eventName := streamPrefix + ".completed"
+ for _, img := range pendingResults {
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ key := openAIResponsesImageResultKey("", img)
+ if _, exists := emitted[key]; exists {
+ continue
+ }
+ payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
+ if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
+ return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
+ }
+ emitted[key] = struct{}{}
+ }
+ imageCount = len(emitted)
+ return usage, imageCount, firstTokenMs, nil
+ }
+
+ streamErr := fmt.Errorf("stream disconnected before image generation completed")
+ _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
+ return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
+}
+
+func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ parsed *OpenAIImagesRequest,
+ channelMappedModel string,
+) (*OpenAIForwardResult, error) {
+ startTime := time.Now()
+ requestModel := strings.TrimSpace(parsed.Model)
+ if mapped := strings.TrimSpace(channelMappedModel); mapped != "" {
+ requestModel = mapped
+ }
+ if requestModel == "" {
+ requestModel = "gpt-image-2"
+ }
+ if err := validateOpenAIImagesModel(requestModel); err != nil {
+ return nil, err
+ }
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d",
+ requestModel,
+ parsed.Endpoint,
+ account.Type,
+ len(parsed.Uploads),
+ )
+ if parsed.N > 1 {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s",
+ parsed.N,
+ requestModel,
+ parsed.Endpoint,
+ )
+ }
+
+ token, _, err := s.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+
+ responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, requestModel)
+ if err != nil {
+ return nil, err
+ }
+ setOpsUpstreamRequestBody(c, responsesBody)
+
+ upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
+ if err != nil {
+ return nil, err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Accept", "text/event-stream")
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+ upstreamStart := time.Now()
+ resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
+ if err != nil {
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ setOpsUpstreamError(c, 0, safeErr, "")
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: 0,
+ UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
+ Kind: "request_error",
+ Message: safeErr,
+ })
+ return nil, fmt.Errorf("upstream request failed: %s", safeErr)
+ }
+ if resp.StatusCode >= 400 {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+ resp.Body = io.NopCloser(bytes.NewReader(respBody))
+ upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: resp.Header.Get("x-request-id"),
+ UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()),
+ Kind: "failover",
+ Message: upstreamMsg,
+ })
+ s.handleFailoverSideEffects(ctx, resp, account)
+ return nil, &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
+ }
+ }
+ return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ var (
+ usage OpenAIUsage
+ imageCount int
+ firstTokenMs *int
+ )
+ if parsed.Stream {
+ usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if imageCount <= 0 {
+ imageCount = parsed.N
+ }
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: requestModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ }, nil
+}
diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go
new file mode 100644
index 00000000..47113d4d
--- /dev/null
+++ b/backend/internal/service/openai_images_test.go
@@ -0,0 +1,856 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "net/textproto"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "/v1/images/generations", parsed.Endpoint)
+ require.Equal(t, "gpt-image-2", parsed.Model)
+ require.Equal(t, "draw a cat", parsed.Prompt)
+ require.True(t, parsed.Stream)
+ require.Equal(t, "1024x1024", parsed.Size)
+ require.Equal(t, "1K", parsed.SizeTier)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+ require.False(t, parsed.Multipart)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ require.NoError(t, writer.WriteField("model", "gpt-image-2"))
+ require.NoError(t, writer.WriteField("prompt", "replace background"))
+ require.NoError(t, writer.WriteField("size", "1536x1024"))
+ part, err := writer.CreateFormFile("image", "source.png")
+ require.NoError(t, err)
+ _, err = part.Write([]byte("fake-image-bytes"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "/v1/images/edits", parsed.Endpoint)
+ require.True(t, parsed.Multipart)
+ require.Equal(t, "gpt-image-2", parsed.Model)
+ require.Equal(t, "replace background", parsed.Prompt)
+ require.Equal(t, "1536x1024", parsed.Size)
+ require.Equal(t, "2K", parsed.SizeTier)
+ require.Len(t, parsed.Uploads, 1)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ require.NoError(t, writer.WriteField("model", "gpt-image-2"))
+ require.NoError(t, writer.WriteField("prompt", "replace foreground"))
+ require.NoError(t, writer.WriteField("output_format", "png"))
+ require.NoError(t, writer.WriteField("input_fidelity", "high"))
+ require.NoError(t, writer.WriteField("output_compression", "80"))
+ require.NoError(t, writer.WriteField("partial_images", "2"))
+
+ imageHeader := make(textproto.MIMEHeader)
+ imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`)
+ imageHeader.Set("Content-Type", "image/png")
+ imagePart, err := writer.CreatePart(imageHeader)
+ require.NoError(t, err)
+ _, err = imagePart.Write([]byte("source-image-bytes"))
+ require.NoError(t, err)
+
+ maskHeader := make(textproto.MIMEHeader)
+ maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`)
+ maskHeader.Set("Content-Type", "image/png")
+ maskPart, err := writer.CreatePart(maskHeader)
+ require.NoError(t, err)
+ _, err = maskPart.Write([]byte("mask-image-bytes"))
+ require.NoError(t, err)
+
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Len(t, parsed.Uploads, 1)
+ require.NotNil(t, parsed.MaskUpload)
+ require.True(t, parsed.HasMask)
+ require.Equal(t, "png", parsed.OutputFormat)
+ require.Equal(t, "high", parsed.InputFidelity)
+ require.NotNil(t, parsed.OutputCompression)
+ require.Equal(t, 80, *parsed.OutputCompression)
+ require.NotNil(t, parsed.PartialImages)
+ require.Equal(t, 2, *parsed.PartialImages)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"prompt":"draw a cat"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "gpt-image-2", parsed.Model)
+ require.Equal(t, OpenAIImagesCapabilityBasic, parsed.RequiredCapability)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNativeCapability(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"prompt":"draw a cat","size":"1024x1024"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.Nil(t, parsed)
+ require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`)
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSONEditURLs(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{
+ "model":"gpt-image-2",
+ "prompt":"replace the background",
+ "images":[{"image_url":"https://example.com/source.png"}],
+ "mask":{"image_url":"https://example.com/mask.png"},
+ "input_fidelity":"high",
+ "output_compression":90,
+ "partial_images":2,
+ "response_format":"url"
+ }`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, []string{"https://example.com/source.png"}, parsed.InputImageURLs)
+ require.Equal(t, "https://example.com/mask.png", parsed.MaskImageURL)
+ require.Equal(t, "high", parsed.InputFidelity)
+ require.NotNil(t, parsed.OutputCompression)
+ require.Equal(t, 90, *parsed.OutputCompression)
+ require.NotNil(t, parsed.PartialImages)
+ require.Equal(t, 2, *parsed.PartialImages)
+ require.True(t, parsed.HasMask)
+ require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
+}
+
+func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) {
+ items := collectOpenAIImagePointers([]byte(`{
+ "revised_prompt": "cat astronaut",
+ "parts": [
+ {"b64_json":"QUJD"},
+ {"download_url":"https://files.example.com/image.png?sig=1"},
+ {"asset_pointer":"file-service://file_123"}
+ ]
+ }`))
+
+ require.Len(t, items, 3)
+ var sawBase64, sawURL, sawPointer bool
+ for _, item := range items {
+ if item.B64JSON == "QUJD" {
+ sawBase64 = true
+ require.Equal(t, "cat astronaut", item.Prompt)
+ }
+ if item.DownloadURL == "https://files.example.com/image.png?sig=1" {
+ sawURL = true
+ }
+ if item.Pointer == "file-service://file_123" {
+ sawPointer = true
+ }
+ }
+ require.True(t, sawBase64)
+ require.True(t, sawURL)
+ require.True(t, sawPointer)
+}
+
+func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) {
+ data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{
+ B64JSON: "data:image/png;base64,QUJD",
+ })
+ require.NoError(t, err)
+ require.Equal(t, []byte("ABC"), data)
+}
+
+func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityBasic))
+ require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative))
+}
+
+func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
+ require.Equal(t,
+ "https://image-upstream.example/v1/images/generations",
+ buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
+ )
+ require.Equal(t,
+ "https://image-upstream.example/v1/images/edits",
+ buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
+ )
+ require.Equal(t,
+ "https://image-upstream.example/v1/images/generations",
+ buildOpenAIImagesURL("https://image-upstream.example", openAIImagesGenerationsEndpoint),
+ )
+ require.Equal(t,
+ "https://image-upstream.example/v1/images/generations",
+ buildOpenAIImagesURL("https://image-upstream.example/v1/images/generations", openAIImagesGenerationsEndpoint),
+ )
+}
+
+type openAIImageTestSSEEvent struct {
+ Name string
+ Data string
+}
+
+func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent {
+ chunks := strings.Split(body, "\n\n")
+ events := make([]openAIImageTestSSEEvent, 0, len(chunks))
+ for _, chunk := range chunks {
+ chunk = strings.TrimSpace(chunk)
+ if chunk == "" {
+ continue
+ }
+ var event openAIImageTestSSEEvent
+ for _, line := range strings.Split(chunk, "\n") {
+ switch {
+ case strings.HasPrefix(line, "event: "):
+ event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: "))
+ case strings.HasPrefix(line, "data: "):
+ event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: "))
+ }
+ }
+ if event.Name != "" || event.Data != "" {
+ events = append(events, event)
+ }
+ }
+ return events
+}
+
+func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, name string) (openAIImageTestSSEEvent, bool) {
+ for _, event := range events {
+ if event.Name == name {
+ return event, true
+ }
+ }
+ return openAIImageTestSSEEvent{}, false
+}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+ c.Set("api_key", &APIKey{ID: 42})
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_123"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc.httpUpstream = upstream
+
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ "chatgpt_account_id": "acct-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "gpt-image-2", result.Model)
+ require.Equal(t, "gpt-image-2", result.UpstreamModel)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, 11, result.Usage.InputTokens)
+ require.Equal(t, 22, result.Usage.OutputTokens)
+ require.Equal(t, 7, result.Usage.ImageOutputTokens)
+
+ require.NotNil(t, upstream.lastReq)
+ require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String())
+ require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
+ require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
+ require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept"))
+ require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id"))
+ require.Equal(t, "responses=experimental", upstream.lastReq.Header.Get("OpenAI-Beta"))
+
+ require.Equal(t, openAIImagesResponsesMainModel, gjson.GetBytes(upstream.lastBody, "model").String())
+ require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool())
+ require.Equal(t, "image_generation", gjson.GetBytes(upstream.lastBody, "tools.0.type").String())
+ require.Equal(t, "generate", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
+ require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
+ require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String())
+ require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String())
+ require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists())
+ require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String())
+ require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+ require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
+}
+
+func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{},
+ httpUpstream: &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "X-Request-Id": []string{"req_img_apikey"},
+ },
+ Body: io.NopCloser(strings.NewReader(`{"created":1710000007,"data":[{"b64_json":"aGVsbG8=","revised_prompt":"draw a cat"}]}`)),
+ },
+ },
+ }
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ account := &Account{
+ ID: 6,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "test-api-key",
+ "base_url": "https://image-upstream.example/v1",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, "gpt-image-2", result.Model)
+ require.Equal(t, "gpt-image-2", result.UpstreamModel)
+
+ upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
+ require.True(t, ok)
+ require.NotNil(t, upstream.lastReq)
+ require.Equal(t, "https://image-upstream.example/v1/images/generations", upstream.lastReq.URL.String())
+ require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+ require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type"))
+ require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "model").String())
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+}
+
+func TestOpenAIGatewayServiceForwardImages_APIKeyEditUsesConfiguredV1BaseURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ require.NoError(t, writer.WriteField("model", "gpt-image-2"))
+ require.NoError(t, writer.WriteField("prompt", "replace background"))
+ imagePart, err := writer.CreateFormFile("image", "source.png")
+ require.NoError(t, err)
+ _, err = imagePart.Write([]byte("png-image-content"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{},
+ httpUpstream: &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "X-Request-Id": []string{"req_img_edit_apikey"},
+ },
+ Body: io.NopCloser(strings.NewReader(`{"created":1710000008,"data":[{"b64_json":"ZWRpdGVk","revised_prompt":"replace background"}]}`)),
+ },
+ },
+ }
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
+ require.NoError(t, err)
+
+ account := &Account{
+ ID: 7,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "test-api-key",
+ "base_url": "https://image-upstream.example/v1/",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.ImageCount)
+
+ upstream, ok := svc.httpUpstream.(*httpUpstreamRecorder)
+ require.True(t, ok)
+ require.NotNil(t, upstream.lastReq)
+ require.Equal(t, "https://image-upstream.example/v1/images/edits", upstream.lastReq.URL.String())
+ require.Equal(t, "Bearer test-api-key", upstream.lastReq.Header.Get("Authorization"))
+ require.Contains(t, upstream.lastReq.Header.Get("Content-Type"), "multipart/form-data")
+ require.Contains(t, string(upstream.lastBody), `name="model"`)
+ require.Contains(t, string(upstream.lastBody), "gpt-image-2")
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_stream"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
+ "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc.httpUpstream = upstream
+
+ account := &Account{
+ ID: 2,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, result.Stream)
+ require.Equal(t, 1, result.ImageCount)
+ events := parseOpenAIImageTestSSEEvents(rec.Body.String())
+ partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image")
+ require.True(t, ok)
+ require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String())
+ require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int())
+ require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
+ require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
+ require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
+ require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String())
+ require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
+ require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
+ require.Equal(t, "auto", gjson.Get(partial.Data, "background").String())
+
+ completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
+ require.True(t, ok)
+ require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
+ require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int())
+ require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
+ require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
+ require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
+ require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String())
+ require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
+ require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
+ require.Equal(t, "auto", gjson.Get(completed.Data, "background").String())
+ require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
+ require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
+}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ require.NoError(t, writer.WriteField("model", "gpt-image-2"))
+ require.NoError(t, writer.WriteField("prompt", "replace background with aurora"))
+ require.NoError(t, writer.WriteField("input_fidelity", "high"))
+ require.NoError(t, writer.WriteField("output_format", "webp"))
+ require.NoError(t, writer.WriteField("quality", "high"))
+
+ imageHeader := make(textproto.MIMEHeader)
+ imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`)
+ imageHeader.Set("Content-Type", "image/png")
+ imagePart, err := writer.CreatePart(imageHeader)
+ require.NoError(t, err)
+ _, err = imagePart.Write([]byte("png-image-content"))
+ require.NoError(t, err)
+
+ maskHeader := make(textproto.MIMEHeader)
+ maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`)
+ maskHeader.Set("Content-Type", "image/png")
+ maskPart, err := writer.CreatePart(maskHeader)
+ require.NoError(t, err)
+ _, err = maskPart.Write([]byte("png-mask-content"))
+ require.NoError(t, err)
+
+ require.NoError(t, writer.Close())
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes()))
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+ c.Set("api_key", &APIKey{ID: 100})
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes())
+ require.NoError(t, err)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_edit_123"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000002,\"usage\":{\"input_tokens\":13,\"output_tokens\":21,\"output_tokens_details\":{\"image_tokens\":8}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\",\"quality\":\"high\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc.httpUpstream = upstream
+
+ account := &Account{
+ ID: 3,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String())
+ require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
+ require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists())
+ require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String())
+ require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,"))
+ require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,"))
+ require.Equal(t, "replace background with aurora", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String())
+ require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
+ require.Equal(t, "replace background with aurora", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String())
+}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{
+ "model":"gpt-image-2",
+ "prompt":"replace background with aurora",
+ "images":[{"image_url":"https://example.com/source.png"}],
+ "mask":{"image_url":"https://example.com/mask.png"},
+ "stream":true,
+ "response_format":"url"
+ }`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" +
+ "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc.httpUpstream = upstream
+
+ account := &Account{
+ ID: 4,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String())
+ require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String())
+ require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String())
+ events := parseOpenAIImageTestSSEEvents(rec.Body.String())
+ partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image")
+ require.True(t, ok)
+ require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String())
+ require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int())
+ require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String())
+ require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String())
+ require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String())
+ require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String())
+ require.Equal(t, "high", gjson.Get(partial.Data, "quality").String())
+ require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String())
+ require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String())
+
+ completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed")
+ require.True(t, ok)
+ require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String())
+ require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int())
+ require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String())
+ require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String())
+ require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
+ require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String())
+ require.Equal(t, "high", gjson.Get(completed.Data, "quality").String())
+ require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String())
+ require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String())
+ require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
+ require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
+}
+
+func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) {
+ parsed := &OpenAIImagesRequest{
+ Endpoint: openAIImagesGenerationsEndpoint,
+ Model: "gpt-image-2",
+ Prompt: "draw a cat",
+ N: 2,
+ }
+
+ body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
+ require.NoError(t, err)
+ require.NotNil(t, body)
+ require.False(t, gjson.GetBytes(body, "tools.0.n").Exists())
+ require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String())
+ require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String())
+}
+
+func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) {
+ parsed := &OpenAIImagesRequest{
+ Endpoint: openAIImagesEditsEndpoint,
+ Model: "gpt-image-2",
+ Prompt: "replace background",
+ InputFidelity: "high",
+ InputImageURLs: []string{
+ "https://example.com/source.png",
+ },
+ }
+
+ body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2")
+ require.NoError(t, err)
+ require.NotNil(t, body)
+ require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists())
+ require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String())
+}
+
+func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) {
+ body := []byte(
+ "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000004}}\n\n" +
+ "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\"}}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000004,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
+ "data: [DONE]\n\n",
+ )
+
+ results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
+ require.NoError(t, err)
+ require.True(t, foundFinal)
+ require.Equal(t, int64(1710000004), createdAt)
+ require.Len(t, results, 1)
+ require.Equal(t, "aGVsbG8=", results[0].Result)
+ require.Equal(t, "draw a cat", results[0].RevisedPrompt)
+ require.Equal(t, "png", firstMeta.OutputFormat)
+ require.JSONEq(t, `{"images":1}`, string(usageRaw))
+}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_stream_output_item_done"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000005,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc.httpUpstream = upstream
+
+ account := &Account{
+ ID: 5,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, result.Stream)
+ require.Equal(t, 1, result.ImageCount)
+ events := parseOpenAIImageTestSSEEvents(rec.Body.String())
+ completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
+ require.True(t, ok)
+ require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String())
+ require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int())
+ require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String())
+ require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String())
+ require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String())
+ require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
+ require.NotContains(t, rec.Body.String(), "event: error")
+}
diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go
index 9bf3fba3..f332633c 100644
--- a/backend/internal/service/openai_model_mapping.go
+++ b/backend/internal/service/openai_model_mapping.go
@@ -1,5 +1,7 @@
package service
+import "strings"
+
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
// forwarding. Group-level default mapping only applies when the account itself
// did not match any explicit model_mapping rule.
@@ -12,8 +14,47 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo
}
mappedModel, matched := account.ResolveMappedModel(requestedModel)
- if !matched && defaultMappedModel != "" {
+ if !matched && defaultMappedModel != "" && !isExplicitCodexModel(requestedModel) {
return defaultMappedModel
}
return mappedModel
}
+
+func isExplicitCodexModel(model string) bool {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return false
+ }
+ if strings.Contains(model, "/") {
+ parts := strings.Split(model, "/")
+ model = parts[len(parts)-1]
+ }
+ model = strings.ToLower(strings.TrimSpace(model))
+ if getNormalizedCodexModel(model) != "" {
+ return true
+ }
+ if strings.HasSuffix(model, "-openai-compact") {
+ base := strings.TrimSuffix(model, "-openai-compact")
+ return getNormalizedCodexModel(base) != ""
+ }
+ return false
+}
+
+// resolveOpenAICompactForwardModel determines the compact-only upstream model
+// for /responses/compact requests. It never affects normal /responses traffic.
+// When no compact-specific mapping matches, the input model is returned as-is.
+func resolveOpenAICompactForwardModel(account *Account, model string) string {
+ trimmedModel := strings.TrimSpace(model)
+ if trimmedModel == "" || account == nil {
+ return trimmedModel
+ }
+
+ mappedModel, matched := account.ResolveCompactMappedModel(trimmedModel)
+ if !matched {
+ return trimmedModel
+ }
+ if trimmedMapped := strings.TrimSpace(mappedModel); trimmedMapped != "" {
+ return trimmedMapped
+ }
+ return trimmedModel
+}
diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go
index cda7e369..4802c089 100644
--- a/backend/internal/service/openai_model_mapping_test.go
+++ b/backend/internal/service/openai_model_mapping_test.go
@@ -15,10 +15,19 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
account: &Account{
Credentials: map[string]any{},
},
- requestedModel: "gpt-5.4",
+ requestedModel: "claude-opus-4-6",
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-4o-mini",
},
+ {
+ name: "preserves explicit gpt-5.4 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.4",
+ defaultMappedModel: "gpt-4o-mini",
+ expectedModel: "gpt-5.4",
+ },
{
name: "preserves exact passthrough mapping instead of group default",
account: &Account{
@@ -58,6 +67,42 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-4o-mini",
expectedModel: "gpt-5.4",
},
+ {
+ name: "preserves codex spark instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.3-codex-spark",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt-5.3-codex-spark",
+ },
+ {
+ name: "preserves gpt-5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.5",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt-5.5",
+ },
+ {
+ name: "preserves openai namespaced gpt-5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "openai/gpt-5.5",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "openai/gpt-5.5",
+ },
+ {
+ name: "preserves compact gpt-5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt-5.5-openai-compact",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt-5.5-openai-compact",
+ },
}
for _, tt := range tests {
@@ -69,14 +114,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
}
}
-func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) {
+func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) {
account := &Account{
Credentials: map[string]any{},
}
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
- if withoutDefault != "gpt-5.1" {
- t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1")
+ if withoutDefault != "gpt-5.4" {
+ t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
}
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
@@ -85,12 +130,81 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
}
}
+func TestResolveOpenAICompactForwardModel(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ model string
+ expectedModel string
+ }{
+ {
+ name: "nil account keeps original model",
+ account: nil,
+ model: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ },
+ {
+ name: "missing compact mapping keeps original model",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ model: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ },
+ {
+ name: "exact compact mapping overrides model",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ },
+ },
+ },
+ model: "gpt-5.4",
+ expectedModel: "gpt-5.4-openai-compact",
+ },
+ {
+ name: "wildcard compact mapping overrides model",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.*": "gpt-5-openai-compact",
+ },
+ },
+ },
+ model: "gpt-5.4",
+ expectedModel: "gpt-5-openai-compact",
+ },
+ {
+ name: "passthrough compact mapping remains unchanged",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4",
+ },
+ },
+ },
+ model: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := resolveOpenAICompactForwardModel(tt.account, tt.model); got != tt.expectedModel {
+ t.Fatalf("resolveOpenAICompactForwardModel(...) = %q, want %q", got, tt.expectedModel)
+ }
+ })
+ }
+}
+
func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{
- "gpt-5.3-codex-spark": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-high": "gpt-5.3-codex",
- "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex",
+ "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
+ "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3": "gpt-5.3-codex",
+ "gpt-image-2": "gpt-image-2",
}
for input, expected := range cases {
@@ -111,7 +225,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
name: "oauth keeps codex normalization behavior",
account: &Account{Type: AccountTypeOAuth},
model: "gemini-3-flash-preview",
- want: "gpt-5.1",
+ want: "gpt-5.4",
},
{
name: "apikey preserves custom compatible model",
diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go
index 69c9de42..049ffdd8 100644
--- a/backend/internal/service/openai_oauth_passthrough_test.go
+++ b/backend/internal/service/openai_oauth_passthrough_test.go
@@ -734,7 +734,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
- require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
+ require.Equal(t, "codex_cli_rs/0.125.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
diff --git a/backend/internal/service/openai_passthrough_normalization_test.go b/backend/internal/service/openai_passthrough_normalization_test.go
new file mode 100644
index 00000000..492ff610
--- /dev/null
+++ b/backend/internal/service/openai_passthrough_normalization_test.go
@@ -0,0 +1,33 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestNormalizeOpenAIPassthroughOAuthBody_RemovesUnsupportedUser(t *testing.T) {
+ body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"prompt_cache_retention":"24h","safety_identifier":"sid","stream_options":{"include_usage":true}}`)
+
+ normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, false)
+ require.NoError(t, err)
+ require.True(t, changed)
+ for _, field := range openAIChatGPTInternalUnsupportedFields {
+ require.False(t, gjson.GetBytes(normalized, field).Exists(), "%s should be stripped", field)
+ }
+ require.True(t, gjson.GetBytes(normalized, "stream").Bool())
+ require.False(t, gjson.GetBytes(normalized, "store").Bool())
+}
+
+func TestNormalizeOpenAIPassthroughOAuthBody_CompactRemovesUnsupportedUser(t *testing.T) {
+ body := []byte(`{"model":"gpt-5.4","input":"hello","user":"user_123","metadata":{"user_id":"user_123"},"stream":true,"store":true}`)
+
+ normalized, changed, err := normalizeOpenAIPassthroughOAuthBody(body, true)
+ require.NoError(t, err)
+ require.True(t, changed)
+ require.False(t, gjson.GetBytes(normalized, "user").Exists())
+ require.False(t, gjson.GetBytes(normalized, "metadata").Exists())
+ require.False(t, gjson.GetBytes(normalized, "stream").Exists())
+ require.False(t, gjson.GetBytes(normalized, "store").Exists())
+}
diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go
index dea3c172..c0f98de4 100644
--- a/backend/internal/service/openai_tool_continuation.go
+++ b/backend/internal/service/openai_tool_continuation.go
@@ -21,7 +21,7 @@ type FunctionCallOutputValidation struct {
}
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
-// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
+// 满足以下任一信号即视为续链:previous_response_id、input 内包含工具输出/item_reference、
// 或显式声明 tools/tool_choice。
func NeedsToolContinuation(reqBody map[string]any) bool {
if reqBody == nil {
@@ -46,7 +46,7 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
continue
}
itemType, _ := itemMap["type"].(string)
- if itemType == "function_call_output" || itemType == "item_reference" {
+ if isCodexToolCallItemType(itemType) || itemType == "item_reference" {
return true
}
}
diff --git a/backend/internal/service/openai_tool_continuation_test.go b/backend/internal/service/openai_tool_continuation_test.go
index fe737ad6..3f415d9d 100644
--- a/backend/internal/service/openai_tool_continuation_test.go
+++ b/backend/internal/service/openai_tool_continuation_test.go
@@ -17,6 +17,9 @@ func TestNeedsToolContinuationSignals(t *testing.T) {
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
+ {name: "tool_search_output", body: map[string]any{"input": []any{map[string]any{"type": "tool_search_output"}}}, want: true},
+ {name: "custom_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "custom_tool_call_output"}}}, want: true},
+ {name: "mcp_tool_call_output", body: map[string]any{"input": []any{map[string]any{"type": "mcp_tool_call_output"}}}, want: true},
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go
index a5b97ca9..4005a921 100644
--- a/backend/internal/service/openai_ws_account_sticky_test.go
+++ b/backend/internal/service/openai_ws_account_sticky_test.go
@@ -37,7 +37,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
- selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
@@ -77,7 +77,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_RateLimitedMiss(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_rl", account.ID, time.Hour))
- selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil)
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "限额中的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_rl")
@@ -129,7 +129,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_DBRuntimeRecheck
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_db_rl", dbAccount.ID, time.Hour))
- selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil)
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_db_rl", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "DB 中已限流的账号不应继续命中 previous_response_id 粘连")
boundAccountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_prev_db_rl")
@@ -164,7 +164,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *test
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
- selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}, false)
require.NoError(t, err)
require.Nil(t, selection)
}
@@ -197,7 +197,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
- selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil, false)
require.NoError(t, err)
require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
}
@@ -258,7 +258,7 @@ func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
- selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil, false)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 83849bf3..d1386b1b 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -1366,16 +1366,27 @@ func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string
func shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled bool,
turn int,
- hasFunctionCallOutput bool,
+ signals ToolContinuationSignals,
currentPreviousResponseID string,
expectedPreviousResponseID string,
) bool {
- if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
+ if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
return false
}
if strings.TrimSpace(currentPreviousResponseID) != "" {
return false
}
+ if signals.HasFunctionCallOutputMissingCallID {
+ return false
+ }
+ // If the client already sent the actual tool-call context, treat this as
+ // a full replay / self-contained continuation payload rather than
+ // downgrading it into an inferred delta continuation. item_reference alone
+ // is not enough on the store=false WS path: it still needs a valid prior
+ // response anchor so upstream can resolve the referenced function_call.
+ if signals.HasToolCallContext {
+ return false
+ }
return strings.TrimSpace(expectedPreviousResponseID) != ""
}
@@ -2366,6 +2377,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("token is empty")
}
+ // 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session
+ // 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧
+ // 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。
+ if s.settingService != nil {
+ if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil {
+ ctx = withOpenAIFastPolicyContext(ctx, settings)
+ }
+ }
+
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
ingressMode := OpenAIWSIngressModeCtxPool
@@ -2524,6 +2544,44 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
normalized = next
}
+ // Apply OpenAI Fast Policy on the response.create frame using the same
+ // evaluator/normalize/scope rules as the HTTP entrypoints. This is the
+ // single integration point for all WS ingress turns (first + follow-up
+ // frames flow through here).
+ //
+ // Model fallback: parseClientPayload above rejects any frame whose
+ // "model" field is missing (line ~2493-2500), so by the time we
+ // reach this point upstreamModel is always derived from a non-empty
+ // per-frame model. The capturedSessionModel fallback used in the
+ // passthrough adapter is therefore not needed in this path.
+ policyApplied, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, upstreamModel, normalized)
+ if policyErr != nil {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", policyErr)
+ }
+ if blocked != nil {
+ // Send a Realtime-style error event to the client first, then
+ // signal the handler to close the connection with PolicyViolation.
+ // We intentionally do NOT forward this frame upstream.
+ //
+ // coder/websocket@v1.8.14 Conn.Write is synchronous and flushes
+ // the underlying bufio writer before returning (write.go:42 →
+ // 307-311), and the subsequent close handshake re-acquires the
+ // same writeFrameMu, so the error event is guaranteed to reach
+ // the kernel send buffer before any close frame is queued.
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+ if eventBytes != nil {
+ writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+ cancel()
+ }
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ blocked.Message,
+ blocked,
+ )
+ }
+ normalized = policyApplied
+
return openAIWSClientPayload{
payloadRaw: normalized,
rawForHash: trimmed,
@@ -3132,13 +3190,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
- hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
+ toolSignals := ToolContinuationSignals{
+ HasFunctionCallOutput: gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists(),
+ }
+ if toolSignals.HasFunctionCallOutput {
+ var currentReqBody map[string]any
+ if err := json.Unmarshal(currentPayload, ¤tReqBody); err == nil {
+ toolSignals = AnalyzeToolContinuationSignals(currentReqBody)
+ }
+ }
+ hasFunctionCallOutput := toolSignals.HasFunctionCallOutput
// store=false + function_call_output 场景必须有续链锚点。
// 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
if shouldInferIngressFunctionCallOutputPreviousResponseID(
storeDisabled,
turn,
- hasFunctionCallOutput,
+ toolSignals,
currentPreviousResponseID,
expectedPrev,
) {
@@ -3800,6 +3867,7 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
previousResponseID string,
requestedModel string,
excludedIDs map[int64]struct{},
+ requireCompact bool,
) (*AccountSelectionResult, error) {
if s == nil {
return nil, nil
@@ -3840,11 +3908,16 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
return nil, nil
}
- account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
+ account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
return nil, nil
}
+ // 兜底:若上游 compact 能力刚被探测为不支持,但 sticky 还在,需要主动放弃。
+ if requireCompact && openAICompactSupportTier(account) == 0 {
+ _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
+ return nil, nil
+ }
result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
index 6bf9a9ff..30fd4142 100644
--- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
+++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
@@ -1354,6 +1354,274 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFun
require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
}
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenToolCallContextPresent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ctx_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{captureConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 114,
+ Name: "openai-ingress-tool-context",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ctx_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call","call_id":"call_ctx_1","name":"shell","arguments":"{}"},{"type":"function_call_output","call_id":"call_ctx_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ctx_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "请求已包含 function_call 上下文时不应自动补齐 previous_response_id")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachWhenOnlyItemReferencesPresent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_ref_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{captureConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 115,
+ Name: "openai-ingress-item-reference",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ref_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"item_reference","id":"call_ref_1"},{"type":"function_call_output","call_id":"call_ref_1","output":"ok"},{"type":"message","role":"user","content":[{"type":"input_text","text":"retry"}]}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_ref_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.Equal(t, "resp_auto_prev_ref_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "仅有 item_reference 不足以自包含 function_call_output,应回填上一轮响应 ID")
+}
+
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
gin.SetMode(gin.TestMode)
prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go
index ff35cb01..c735f50a 100644
--- a/backend/internal/service/openai_ws_forwarder_ingress_test.go
+++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go
@@ -232,67 +232,91 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
name string
storeDisabled bool
turn int
- hasFunctionCallOutput bool
+ signals ToolContinuationSignals
currentPreviousResponse string
expectedPrevious string
want bool
}{
{
- name: "infer_when_all_conditions_match",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: "resp_1",
- want: true,
+ name: "infer_when_all_conditions_match",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "resp_1",
+ want: true,
},
{
- name: "skip_when_store_enabled",
- storeDisabled: false,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: "resp_1",
- want: false,
+ name: "skip_when_store_enabled",
+ storeDisabled: false,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "resp_1",
+ want: false,
},
{
- name: "skip_on_first_turn",
- storeDisabled: true,
- turn: 1,
- hasFunctionCallOutput: true,
- expectedPrevious: "resp_1",
- want: false,
+ name: "skip_on_first_turn",
+ storeDisabled: true,
+ turn: 1,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "resp_1",
+ want: false,
},
{
- name: "skip_without_function_call_output",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: false,
- expectedPrevious: "resp_1",
- want: false,
+ name: "skip_without_function_call_output",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{},
+ expectedPrevious: "resp_1",
+ want: false,
},
{
name: "skip_when_request_already_has_previous_response_id",
storeDisabled: true,
turn: 2,
- hasFunctionCallOutput: true,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
currentPreviousResponse: "resp_client",
expectedPrevious: "resp_1",
want: false,
},
{
- name: "skip_when_last_turn_response_id_missing",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: "",
- want: false,
+ name: "skip_when_last_turn_response_id_missing",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: "",
+ want: false,
},
{
- name: "trim_whitespace_before_judgement",
- storeDisabled: true,
- turn: 2,
- hasFunctionCallOutput: true,
- expectedPrevious: " resp_2 ",
- want: true,
+ name: "trim_whitespace_before_judgement",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true},
+ expectedPrevious: " resp_2 ",
+ want: true,
+ },
+ {
+ name: "skip_when_tool_call_context_already_present",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasToolCallContext: true},
+ expectedPrevious: "resp_2",
+ want: false,
+ },
+ {
+ name: "infer_when_only_item_reference_covers_call_ids",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasItemReferenceForAllCallIDs: true},
+ expectedPrevious: "resp_2",
+ want: true,
+ },
+ {
+ name: "skip_when_function_call_output_missing_call_id",
+ storeDisabled: true,
+ turn: 2,
+ signals: ToolContinuationSignals{HasFunctionCallOutput: true, HasFunctionCallOutputMissingCallID: true},
+ expectedPrevious: "resp_2",
+ want: false,
},
}
@@ -303,7 +327,7 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
got := shouldInferIngressFunctionCallOutputPreviousResponseID(
tt.storeDisabled,
tt.turn,
- tt.hasFunctionCallOutput,
+ tt.signals,
tt.currentPreviousResponse,
tt.expectedPrevious,
)
diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go
index 66e5db93..f3936de1 100644
--- a/backend/internal/service/openai_ws_protocol_forward_test.go
+++ b/backend/internal/service/openai_ws_protocol_forward_test.go
@@ -618,6 +618,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
+ nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
index cda2e351..3dbb199a 100644
--- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go
+++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
@@ -21,6 +21,109 @@ type openAIWSClientFrameConn struct {
conn *coderws.Conn
}
+// openAIWSPolicyEnforcingFrameConn wraps a client-side FrameConn and runs
+// every client→upstream frame through the OpenAI Fast Policy. It is the
+// passthrough-relay equivalent of the parseClientPayload integration in the
+// ingress session path. filter returns:
+// - newPayload, nil, nil: forward the (possibly mutated) payload
+// - _, *OpenAIFastBlockedError, nil: block — the wrapper sends an error
+// event via onBlock and surfaces a transport-level error so the relay
+// stops reading from the client.
+// - _, _, err: a transport error other than block.
+type openAIWSPolicyEnforcingFrameConn struct {
+ inner openaiwsv2.FrameConn
+ filter func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error)
+ onBlock func(blocked *OpenAIFastBlockedError)
+}
+
+var _ openaiwsv2.FrameConn = (*openAIWSPolicyEnforcingFrameConn)(nil)
+
+func (c *openAIWSPolicyEnforcingFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ if c == nil || c.inner == nil {
+ return coderws.MessageText, nil, errOpenAIWSConnClosed
+ }
+ msgType, payload, err := c.inner.ReadFrame(ctx)
+ if err != nil {
+ return msgType, payload, err
+ }
+ if c.filter == nil {
+ return msgType, payload, nil
+ }
+ updated, blocked, filterErr := c.filter(msgType, payload)
+ if filterErr != nil {
+ return msgType, payload, filterErr
+ }
+ if blocked != nil {
+ if c.onBlock != nil {
+ c.onBlock(blocked)
+ }
+ return msgType, nil, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
+ }
+ return msgType, updated, nil
+}
+
+func (c *openAIWSPolicyEnforcingFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+ if c == nil || c.inner == nil {
+ return errOpenAIWSConnClosed
+ }
+ return c.inner.WriteFrame(ctx, msgType, payload)
+}
+
+func (c *openAIWSPolicyEnforcingFrameConn) Close() error {
+ if c == nil || c.inner == nil {
+ return nil
+ }
+ return c.inner.Close()
+}
+
+// openAIWSPassthroughPolicyModelForFrame returns the upstream-perspective
+// model name that should be passed to evaluateOpenAIFastPolicy for a single
+// passthrough WS frame. Mirrors the HTTP-side normalization
+// (account.GetMappedModel + normalizeOpenAIModelForUpstream) so the WS path
+// matches model whitelists identically.
+func openAIWSPassthroughPolicyModelForFrame(account *Account, payload []byte) string {
+ if account == nil || len(payload) == 0 {
+ return ""
+ }
+ original := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
+ if original == "" {
+ return ""
+ }
+ return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
+}
+
+// openAIWSPassthroughPolicyModelFromSessionFrame returns the upstream model
+// derived from a session.update frame's session.model field. Returns "" when
+// the frame is not a session.update event or carries no session.model. Used
+// by the per-frame policy filter (client→upstream direction) to keep
+// capturedSessionModel in sync with the session-level model the client may
+// rotate mid-session.
+//
+// Realtime / Responses WS lets the client change the session model after
+// the WS handshake via:
+//
+// {"type":"session.update","session":{"model":"gpt-5.5", ...}}
+//
+// If we only capture the model from the very first frame, a client can ship
+// gpt-4o on the first response.create (whitelisted as pass), then
+// session.update to gpt-5.5, then send response.create without "model" so
+// the per-frame resolver returns "" and the stale capturedSessionModel falls
+// back to gpt-4o — defeating the gpt-5.5 fast-policy filter.
+func openAIWSPassthroughPolicyModelFromSessionFrame(account *Account, payload []byte) string {
+ if account == nil || len(payload) == 0 {
+ return ""
+ }
+ frameType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
+ if frameType != "session.update" {
+ return ""
+ }
+ original := strings.TrimSpace(gjson.GetBytes(payload, "session.model").String())
+ if original == "" {
+ return ""
+ }
+ return normalizeOpenAIModelForUpstream(account, account.GetMappedModel(original))
+}
+
const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
@@ -77,7 +180,6 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
return errors.New("token is empty")
}
requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
- requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage)
requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String())
logOpenAIWSV2Passthrough(
"relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
@@ -88,6 +190,59 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
len(firstClientMessage),
)
+ // Apply OpenAI Fast Policy on the first response.create frame. Subsequent
+ // frames are filtered via a wrapping FrameConn below so every client→
+ // upstream frame goes through the same policy evaluator/normalize/scope as
+ // HTTP entrypoints.
+ //
+ // We capture the session-level model from the first frame here so the
+ // per-frame filter (below) can fall back to it when a follow-up frame
+ // omits "model" — Realtime clients are allowed to send response.create
+ // without re-stating the model, in which case the upstream uses the model
+ // negotiated at session.update time. Without this fallback, an empty
+ // model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
+ // silently passed through, defeating the policy on every frame after
+ // the first.
+ capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
+ updatedFirst, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, capturedSessionModel, firstClientMessage)
+ if policyErr != nil {
+ return fmt.Errorf("apply openai fast policy on first ws frame: %w", policyErr)
+ }
+ if blocked != nil {
+ // coder/websocket@v1.8.14 Conn.Write is synchronous: it acquires
+ // writeFrameMu, writes the entire frame, and Flushes the underlying
+ // bufio writer before returning (write.go:42 → write.go:307-311).
+ // The subsequent close handshake re-acquires the same writeFrameMu
+ // to send the close frame, so the error event is guaranteed to
+ // reach the kernel send buffer before any close frame is queued.
+ // No explicit flush hop is required here.
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+ if eventBytes != nil {
+ writeCtx, cancelWrite := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+ cancelWrite()
+ }
+ return NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, blocked.Message, blocked)
+ }
+ firstClientMessage = updatedFirst
+
+ // 在 policy filter 之后再提取 service_tier 用于 billing 上报:filter
+ // 命中时 service_tier 已经从 firstClientMessage 中删除,billing 应当
+ // 反映上游实际处理的 tier(nil = default),而不是用户最初请求的
+ // "priority"。HTTP 入口(line ~2728 extractOpenAIServiceTier(reqBody))
+ // 与 WS ingress(openai_ws_forwarder.go:2991 取自 payload)的语义一致。
+ //
+ // 多轮 passthrough:OpenAI Realtime / Responses WS 协议允许客户端在
+ // 同一连接的不同 response.create 帧上发送不同 service_tier(参考
+ // codex-rs/core/src/client.rs build_responses_request 每次重新填值)。
+ // 因此使用 atomic.Pointer[string] 在 filter(runClientToUpstream
+ // goroutine)和 OnTurnComplete / final result(runUpstreamToClient
+ // goroutine)之间同步当前 turn 的 service_tier。
+ // extractOpenAIServiceTierFromBody 返回 *string,本身是指针类型,
+ // 可直接 Store/Load 而无需额外封装。
+ var requestServiceTierPtr atomic.Pointer[string]
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(firstClientMessage))
+
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
@@ -152,9 +307,72 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
}
completedTurns := atomic.Int32{}
+ policyClientConn := &openAIWSPolicyEnforcingFrameConn{
+ inner: &openAIWSClientFrameConn{conn: clientConn},
+ // 注意线程安全:filter 仅在 runClientToUpstream 这一条
+ // goroutine 中被调用(passthrough_relay.go: ReadFrame loop),
+ // capturedSessionModel 的读写都发生在该 goroutine 内,因此无需
+ // 加锁/原子化。
+ filter: func(msgType coderws.MessageType, payload []byte) ([]byte, *OpenAIFastBlockedError, error) {
+ if msgType != coderws.MessageText {
+ return payload, nil, nil
+ }
+ // 在评估策略前先刷新 capturedSessionModel:客户端可能通过
+ // session.update 修改 session-level model(Realtime /
+ // Responses WS 协议允许),如果不刷新就会出现
+ // "首帧 model=gpt-4o(pass)→ session.update 改成 gpt-5.5
+ // → 不带 model 的 response.create fallback 到 gpt-4o" 的
+ // 绕过路径。这里只看 session.update 事件中的 session.model
+ // 字段,response.create 自己的 model 仍然由其本帧字段决定。
+ if updated := openAIWSPassthroughPolicyModelFromSessionFrame(account, payload); updated != "" {
+ capturedSessionModel = updated
+ }
+ // Per-frame model first; if the client omits "model" on a
+ // follow-up frame (legal in Realtime), fall back to the
+ // session-level model captured from the first frame so the
+ // model whitelist still resolves. An empty model would miss
+ // any whitelist and silently fall back to pass.
+ model := openAIWSPassthroughPolicyModelForFrame(account, payload)
+ if model == "" {
+ model = capturedSessionModel
+ }
+ out, blocked, policyErr := s.applyOpenAIFastPolicyToWSResponseCreate(ctx, account, model, payload)
+ // 多轮 passthrough billing:仅在成功(non-block / non-err)
+ // 的 response.create 帧上更新 requestServiceTierPtr,使用
+ // filter 处理后的 payload,与首帧 policy-after-extract 语义
+ // 保持一致(参见上方 extractOpenAIServiceTierFromBody 注释)。
+ // - 非 response.create 帧(response.cancel /
+ // conversation.item.create / session.update 等)不携带
+ // per-response service_tier,不应覆盖前一轮值。
+ // - blocked != nil:该帧不会发送上游,billing tier 应保持
+ // 上一轮值。
+ // - policyErr != nil:异常路径,保持上一轮值。
+ // - 不带 service_tier 的 response.create 会让
+ // extractOpenAIServiceTierFromBody 返回 nil;这里有意
+ // 覆盖(Store(nil)),因为 OpenAI 上游对该帧实际不传
+ // service_tier 时按 default 处理,billing 应如实反映。
+ if policyErr == nil && blocked == nil &&
+ strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.create" {
+ requestServiceTierPtr.Store(extractOpenAIServiceTierFromBody(out))
+ }
+ return out, blocked, policyErr
+ },
+ onBlock: func(blocked *OpenAIFastBlockedError) {
+ // See note above on Conn.Write being synchronous w.r.t. flush;
+ // no explicit flush is required to ensure the error event lands
+ // before the close frame.
+ eventBytes := buildOpenAIFastPolicyBlockedWSEvent(blocked)
+ if eventBytes == nil {
+ return
+ }
+ writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ _ = clientConn.Write(writeCtx, coderws.MessageText, eventBytes)
+ cancel()
+ },
+ }
relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
Ctx: ctx,
- ClientConn: &openAIWSClientFrameConn{conn: clientConn},
+ ClientConn: policyClientConn,
UpstreamConn: upstreamFrameConn,
FirstClientMessage: firstClientMessage,
Options: openaiwsv2.RelayOptions{
@@ -179,7 +397,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
},
Model: turn.RequestModel,
- ServiceTier: requestServiceTier,
+ ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
@@ -227,7 +445,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
},
Model: relayResult.RequestModel,
- ServiceTier: requestServiceTier,
+ ServiceTier: requestServiceTierPtr.Load(),
Stream: true,
OpenAIWSMode: true,
ResponseHeaders: cloneHeader(handshakeHeaders),
diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go
index 1cae6fe5..44ec1ad1 100644
--- a/backend/internal/service/ops_cleanup_service.go
+++ b/backend/internal/service/ops_cleanup_service.go
@@ -36,11 +36,15 @@ return 0
// - Scheduling: 5-field cron spec (minute hour dom month dow).
// - Multi-instance: best-effort Redis leader lock so only one node runs cleanup.
// - Safety: deletes in batches to avoid long transactions.
+//
+// 附带:在 runCleanupOnce 末尾调用 ChannelMonitorService.RunDailyMaintenance,
+// 统一共享 cron schedule + leader lock + heartbeat,避免再引一套调度。
type OpsCleanupService struct {
- opsRepo OpsRepository
- db *sql.DB
- redisClient *redis.Client
- cfg *config.Config
+ opsRepo OpsRepository
+ db *sql.DB
+ redisClient *redis.Client
+ cfg *config.Config
+ channelMonitorSvc *ChannelMonitorService
instanceID string
@@ -57,13 +61,15 @@ func NewOpsCleanupService(
db *sql.DB,
redisClient *redis.Client,
cfg *config.Config,
+ channelMonitorSvc *ChannelMonitorService,
) *OpsCleanupService {
return &OpsCleanupService{
- opsRepo: opsRepo,
- db: db,
- redisClient: redisClient,
- cfg: cfg,
- instanceID: uuid.NewString(),
+ opsRepo: opsRepo,
+ db: db,
+ redisClient: redisClient,
+ cfg: cfg,
+ channelMonitorSvc: channelMonitorSvc,
+ instanceID: uuid.NewString(),
}
}
@@ -178,6 +184,25 @@ func (c opsCleanupDeletedCounts) String() string {
)
}
+// opsCleanupPlan 把"保留天数"翻译成具体的清理动作。
+// - days < 0 → 跳过该项清理(ok=false),保留兼容老数据
+// - days == 0 → TRUNCATE TABLE(O(1) 全清),truncate=true
+// - days > 0 → 批量 DELETE 早于 now-N天 的行,cutoff = now - N 天
+//
+// 之所以 days==0 走 TRUNCATE 而非"now+24h cutoff + DELETE":
+// - 速度从 O(N) 降到 O(1),对百万行级表毫秒完成
+// - 无 WAL 写入、无后续 VACUUM 压力
+// - 这些 ops 表只有 cleanup 任务自己写,TRUNCATE 的 ACCESS EXCLUSIVE 锁影响可忽略
+func opsCleanupPlan(now time.Time, days int) (cutoff time.Time, truncate, ok bool) {
+ if days < 0 {
+ return time.Time{}, false, false
+ }
+ if days == 0 {
+ return time.Time{}, true, true
+ }
+ return now.AddDate(0, 0, -days), false, true
+}
+
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
out := opsCleanupDeletedCounts{}
if s == nil || s.db == nil || s.cfg == nil {
@@ -188,34 +213,42 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
now := time.Now().UTC()
- // Error-like tables: error logs / retry attempts / alert events.
- if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
- cutoff := now.AddDate(0, 0, -days)
- n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
+ // runOne 把"truncate? cutoff? batched delete?"封装到一处,
+ // 让三组清理(错误日志类 / 分钟指标 / 小时+日预聚合)调用方只关心表名和列名。
+ runOne := func(truncate bool, cutoff time.Time, table, timeCol string, castDate bool) (int64, error) {
+ if truncate {
+ return truncateOpsTable(ctx, s.db, table)
+ }
+ return deleteOldRowsByID(ctx, s.db, table, timeCol, cutoff, batchSize, castDate)
+ }
+
+ // Error-like tables: error logs / retry attempts / alert events / system logs / cleanup audits.
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.ErrorLogRetentionDays); ok {
+ n, err := runOne(truncate, cutoff, "ops_error_logs", "created_at", false)
if err != nil {
return out, err
}
out.errorLogs = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_retry_attempts", "created_at", false)
if err != nil {
return out, err
}
out.retryAttempts = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_alert_events", "created_at", false)
if err != nil {
return out, err
}
out.alertEvents = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_system_logs", "created_at", false)
if err != nil {
return out, err
}
out.systemLogs = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false)
+ n, err = runOne(truncate, cutoff, "ops_system_log_cleanup_audits", "created_at", false)
if err != nil {
return out, err
}
@@ -223,9 +256,8 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Minute-level metrics snapshots.
- if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
- cutoff := now.AddDate(0, 0, -days)
- n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays); ok {
+ n, err := runOne(truncate, cutoff, "ops_system_metrics", "created_at", false)
if err != nil {
return out, err
}
@@ -233,21 +265,29 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
}
// Pre-aggregation tables (hourly/daily).
- if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
- cutoff := now.AddDate(0, 0, -days)
- n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
+ if cutoff, truncate, ok := opsCleanupPlan(now, s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays); ok {
+ n, err := runOne(truncate, cutoff, "ops_metrics_hourly", "bucket_start", false)
if err != nil {
return out, err
}
out.hourlyPreagg = n
- n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
+ n, err = runOne(truncate, cutoff, "ops_metrics_daily", "bucket_date", true)
if err != nil {
return out, err
}
out.dailyPreagg = n
}
+ // Channel monitor 每日维护(聚合昨日明细 + 软删过期明细/聚合)。
+ // 失败只记日志,不影响 ops 清理的成功状态(与 ops 各步骤风格一致);
+ // 维护本身已经把每步错误打到 slog,heartbeat result 不再分项记录。
+ if s.channelMonitorSvc != nil {
+ if err := s.channelMonitorSvc.RunDailyMaintenance(ctx); err != nil {
+ logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] channel monitor maintenance failed: %v", err)
+ }
+ }
+
return out, nil
}
@@ -288,7 +328,7 @@ WHERE id IN (SELECT id FROM batch)
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
if err != nil {
// If ops tables aren't present yet (partial deployments), treat as no-op.
- if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
+ if isMissingRelationError(err) {
return total, nil
}
return total, err
@@ -305,6 +345,46 @@ WHERE id IN (SELECT id FROM batch)
return total, nil
}
+// truncateOpsTable 用 TRUNCATE TABLE 清空指定表,先 SELECT COUNT(*) 取得清空前行数用于 heartbeat。
+//
+// 与 deleteOldRowsByID 的差异:
+// - 不可指定 WHERE 条件,仅用于 days==0 的"清空全部"语义
+// - O(1) 释放表的物理存储页,毫秒级完成,无 WAL 写入、无 VACUUM 压力
+// - 需要 ACCESS EXCLUSIVE 锁,但 ops 表只有清理任务自己写入,瞬间锁影响可忽略
+//
+// 表不存在(部分部署)静默返回 0,与 deleteOldRowsByID 保持一致。
+func truncateOpsTable(ctx context.Context, db *sql.DB, table string) (int64, error) {
+ if db == nil {
+ return 0, nil
+ }
+ var count int64
+ if err := db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", table)).Scan(&count); err != nil {
+ if isMissingRelationError(err) {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("count %s: %w", table, err)
+ }
+ if count == 0 {
+ return 0, nil
+ }
+ if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE TABLE %s", table)); err != nil {
+ if isMissingRelationError(err) {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("truncate %s: %w", table, err)
+ }
+ return count, nil
+}
+
+// isMissingRelationError 判断 PG 报错是否为"表不存在",用于让清理任务在部分部署场景静默跳过。
+func isMissingRelationError(err error) bool {
+ if err == nil {
+ return false
+ }
+ s := strings.ToLower(err.Error())
+ return strings.Contains(s, "does not exist") && strings.Contains(s, "relation")
+}
+
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
if s == nil {
return nil, false
diff --git a/backend/internal/service/ops_cleanup_service_test.go b/backend/internal/service/ops_cleanup_service_test.go
new file mode 100644
index 00000000..86657d27
--- /dev/null
+++ b/backend/internal/service/ops_cleanup_service_test.go
@@ -0,0 +1,64 @@
+package service
+
+import (
+ "testing"
+ "time"
+)
+
+func TestOpsCleanupPlan(t *testing.T) {
+ now := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC)
+
+ cases := []struct {
+ name string
+ days int
+ wantOK bool
+ wantTruncate bool
+ wantCutoff time.Time
+ }{
+ {name: "negative skips", days: -1, wantOK: false},
+ {name: "zero truncates", days: 0, wantOK: true, wantTruncate: true},
+ {name: "positive yields past cutoff", days: 7, wantOK: true, wantCutoff: now.AddDate(0, 0, -7)},
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ cutoff, truncate, ok := opsCleanupPlan(now, tc.days)
+ if ok != tc.wantOK {
+ t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
+ }
+ if !ok {
+ return
+ }
+ if truncate != tc.wantTruncate {
+ t.Fatalf("truncate = %v, want %v", truncate, tc.wantTruncate)
+ }
+ if !tc.wantTruncate && !cutoff.Equal(tc.wantCutoff) {
+ t.Fatalf("cutoff = %v, want %v", cutoff, tc.wantCutoff)
+ }
+ })
+ }
+}
+
+func TestIsMissingRelationError(t *testing.T) {
+ cases := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {name: "nil is not missing", err: nil, want: false},
+ {name: "match relation does not exist", err: fakeErr(`pq: relation "ops_error_logs" does not exist`), want: true},
+ {name: "match case-insensitive", err: fakeErr(`ERROR: Relation "x" Does Not Exist`), want: true},
+ {name: "non-matching error", err: fakeErr("connection refused"), want: false},
+ }
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ if got := isMissingRelationError(tc.err); got != tc.want {
+ t.Fatalf("got %v, want %v", got, tc.want)
+ }
+ })
+ }
+}
+
+type fakeErr string
+
+func (e fakeErr) Error() string { return string(e) }
diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go
index c0e814ab..bd40d389 100644
--- a/backend/internal/service/ops_retry.go
+++ b/backend/internal/service/ops_retry.go
@@ -388,7 +388,7 @@ func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDeta
func detectOpsRetryType(path string) opsRetryRequestType {
p := strings.ToLower(strings.TrimSpace(path))
switch {
- case strings.Contains(p, "/responses"):
+ case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
return opsRetryTypeOpenAI
case strings.Contains(p, "/v1beta/"):
return opsRetryTypeGeminiV1B
diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go
index 5871166c..ecc3a94b 100644
--- a/backend/internal/service/ops_settings.go
+++ b/backend/internal/service/ops_settings.go
@@ -387,13 +387,15 @@ func normalizeOpsAdvancedSettings(cfg *OpsAdvancedSettings) {
if cfg.DataRetention.CleanupSchedule == "" {
cfg.DataRetention.CleanupSchedule = "0 2 * * *"
}
- if cfg.DataRetention.ErrorLogRetentionDays <= 0 {
+ // 保留天数:0 表示每次定时清理全部(清空所有),> 0 表示按天数保留;
+ // 仅在拿到非法的负数时回填默认值,避免覆盖用户主动设的 0。
+ if cfg.DataRetention.ErrorLogRetentionDays < 0 {
cfg.DataRetention.ErrorLogRetentionDays = 30
}
- if cfg.DataRetention.MinuteMetricsRetentionDays <= 0 {
+ if cfg.DataRetention.MinuteMetricsRetentionDays < 0 {
cfg.DataRetention.MinuteMetricsRetentionDays = 30
}
- if cfg.DataRetention.HourlyMetricsRetentionDays <= 0 {
+ if cfg.DataRetention.HourlyMetricsRetentionDays < 0 {
cfg.DataRetention.HourlyMetricsRetentionDays = 30
}
// Normalize auto refresh interval (default 30 seconds)
@@ -406,14 +408,15 @@ func validateOpsAdvancedSettings(cfg *OpsAdvancedSettings) error {
if cfg == nil {
return errors.New("invalid config")
}
- if cfg.DataRetention.ErrorLogRetentionDays < 1 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
- return errors.New("error_log_retention_days must be between 1 and 365")
+ // 保留天数:0 表示每次清理全部,1-365 表示按天数保留。
+ if cfg.DataRetention.ErrorLogRetentionDays < 0 || cfg.DataRetention.ErrorLogRetentionDays > 365 {
+ return errors.New("error_log_retention_days must be between 0 and 365")
}
- if cfg.DataRetention.MinuteMetricsRetentionDays < 1 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
- return errors.New("minute_metrics_retention_days must be between 1 and 365")
+ if cfg.DataRetention.MinuteMetricsRetentionDays < 0 || cfg.DataRetention.MinuteMetricsRetentionDays > 365 {
+ return errors.New("minute_metrics_retention_days must be between 0 and 365")
}
- if cfg.DataRetention.HourlyMetricsRetentionDays < 1 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
- return errors.New("hourly_metrics_retention_days must be between 1 and 365")
+ if cfg.DataRetention.HourlyMetricsRetentionDays < 0 || cfg.DataRetention.HourlyMetricsRetentionDays > 365 {
+ return errors.New("hourly_metrics_retention_days must be between 0 and 365")
}
if cfg.AutoRefreshIntervalSec < 15 || cfg.AutoRefreshIntervalSec > 300 {
return errors.New("auto_refresh_interval_seconds must be between 15 and 300")
diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go
index 56905278..973c601a 100644
--- a/backend/internal/service/payment_config_limits.go
+++ b/backend/internal/service/payment_config_limits.go
@@ -20,6 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return nil, fmt.Errorf("query provider instances: %w", err)
}
typeInstances := pcGroupByPaymentType(instances)
+ typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances)
resp := &MethodLimitsResponse{
Methods: make(map[string]MethodLimits, len(typeInstances)),
}
@@ -31,6 +32,41 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M
return resp, nil
}
+func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance {
+ if len(typeInstances) == 0 {
+ return typeInstances
+ }
+
+ filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances))
+ for paymentType, groupedInstances := range typeInstances {
+ filtered[paymentType] = groupedInstances
+ }
+
+ for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ matching := filterEnabledVisibleMethodInstances(instances, method)
+ providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
+ if err != nil {
+ delete(filtered, method)
+ continue
+ }
+ if providerKey == "" {
+ if len(matching) == 0 {
+ delete(filtered, method)
+ continue
+ }
+ filtered[method] = matching
+ continue
+ }
+ selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey)
+ if len(selectedInstances) == 0 {
+ delete(filtered, method)
+ continue
+ }
+ filtered[method] = selectedInstances
+ }
+ return filtered
+}
+
// GetMethodLimits returns per-payment-type limits from enabled provider instances.
func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go
index 73ad66ef..4df506d6 100644
--- a/backend/internal/service/payment_config_limits_test.go
+++ b/backend/internal/service/payment_config_limits_test.go
@@ -1,10 +1,12 @@
package service
import (
+ "context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
)
func TestUnionFloat(t *testing.T) {
@@ -299,3 +301,161 @@ func TestPcInstanceTypeLimits(t *testing.T) {
}
})
}
+
+func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) {
+ tests := []struct {
+ name string
+ sourceSetting string
+ wantAlipaySingleMin float64
+ wantAlipaySingleMax float64
+ wantGlobalMin float64
+ wantGlobalMax float64
+ }{
+ {
+ name: "official source",
+ sourceSetting: VisibleMethodSourceOfficialAlipay,
+ wantAlipaySingleMin: 10,
+ wantAlipaySingleMax: 100,
+ wantGlobalMin: 10,
+ wantGlobalMax: 300,
+ },
+ {
+ name: "easypay source",
+ sourceSetting: VisibleMethodSourceEasyPayAlipay,
+ wantAlipaySingleMin: 20,
+ wantAlipaySingleMax: 200,
+ wantGlobalMin: 20,
+ wantGlobalMax: 300,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay alipay instance: %v", err)
+ }
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting,
+ },
+ },
+ }
+
+ resp, err := svc.GetAvailableMethodLimits(ctx)
+ if err != nil {
+ t.Fatalf("GetAvailableMethodLimits returned error: %v", err)
+ }
+
+ alipayLimits, ok := resp.Methods[payment.TypeAlipay]
+ if !ok {
+ t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods)
+ }
+ if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax {
+ t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax)
+ }
+
+ wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
+ if !ok {
+ t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods)
+ }
+ if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 {
+ t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits)
+ }
+ if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax {
+ t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax)
+ }
+ })
+ }
+}
+
+func TestGetAvailableMethodLimitsPreservesLegacyCrossProviderBehaviorWhenVisibleMethodSourceMissing(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`).
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Mixed").
+ SetConfig("{}").
+ SetSupportedTypes("alipay,wxpay").
+ SetLimits(`{"alipay":{"singleMin":20,"singleMax":200},"wxpay":{"singleMin":40,"singleMax":400}}`).
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`).
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{}},
+ }
+
+ resp, err := svc.GetAvailableMethodLimits(ctx)
+ require.NoError(t, err)
+
+ alipayLimits, ok := resp.Methods[payment.TypeAlipay]
+ require.True(t, ok, "expected alipay limits to remain visible")
+ require.Equal(t, 10.0, alipayLimits.SingleMin)
+ require.Equal(t, 200.0, alipayLimits.SingleMax)
+
+ wxpayLimits, ok := resp.Methods[payment.TypeWxpay]
+ require.True(t, ok, "expected wxpay limits to remain visible")
+ require.Equal(t, 30.0, wxpayLimits.SingleMin)
+ require.Equal(t, 400.0, wxpayLimits.SingleMax)
+
+ require.Equal(t, 10.0, resp.GlobalMin)
+ require.Equal(t, 400.0, resp.GlobalMax)
+}
diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go
index 3c406b45..ff05e559 100644
--- a/backend/internal/service/payment_config_providers.go
+++ b/backend/internal/service/payment_config_providers.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"strconv"
"strings"
@@ -11,9 +12,22 @@ import (
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/payment/provider"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
+// validateProviderConfig runs the provider's constructor to surface config-level
+// errors at save time (e.g. wxpay missing certSerial), instead of only failing
+// when an order is created. Returns the structured ApplicationError from the
+// constructor so the frontend i18n layer can localize it.
+//
+// Only validates enabled instances — a disabled instance may be a half-filled
+// draft the admin will complete later.
+func (s *PaymentConfigService) validateProviderConfig(providerKey string, config map[string]string) error {
+ _, err := provider.CreateProvider(providerKey, "_validate_", config)
+ return err
+}
+
// --- Provider Instance CRUD ---
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
@@ -47,11 +61,10 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
- Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
- AllowUserRefund: inst.AllowUserRefund,
- SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
+ Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund,
+ SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
- resp.Config, err = s.decryptAndMaskConfig(inst.Config)
+ resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
}
@@ -60,8 +73,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
return result, nil
}
-func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) {
- return s.decryptConfig(encrypted)
+// decryptAndMaskConfig returns the stored config with sensitive fields omitted.
+// Admin UIs display masked placeholders for these; the raw values never leave
+// the server. Callers that need the full config (e.g. payment runtime) must
+// use decryptConfig directly.
+func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
+ cfg, err := s.decryptConfig(encrypted)
+ if err != nil {
+ return nil, err
+ }
+ if cfg == nil {
+ return nil, nil
+ }
+ masked := make(map[string]string, len(cfg))
+ for k, v := range cfg {
+ if isSensitiveProviderConfigField(providerKey, k) {
+ continue
+ }
+ masked[k] = v
+ }
+ return masked, nil
}
// pendingOrderStatuses are order statuses considered "in progress".
@@ -71,18 +102,62 @@ var pendingOrderStatuses = []string{
payment.OrderStatusRecharging,
}
-var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"}
+// providerSensitiveConfigFields is the authoritative list of config keys that
+// are treated as secrets per provider. Must stay in sync with the frontend
+// definition at frontend/src/components/payment/providerConfig.ts
+// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
+//
+// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
+// stripe publishableKey) are returned in plaintext by the admin GET API.
+var providerSensitiveConfigFields = map[string]map[string]struct{}{
+ payment.TypeEasyPay: {"pkey": {}},
+ payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
+ payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
+ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
+}
-func isSensitiveConfigField(fieldName string) bool {
- lower := strings.ToLower(fieldName)
- for _, p := range sensitiveConfigPatterns {
- if strings.Contains(lower, p) {
+// providerPendingOrderProtectedConfigFields lists config keys that cannot be
+// changed while the instance has in-progress orders. This includes secrets plus
+// all provider identity fields that are snapshotted into orders or used by
+// webhook/refund verification.
+var providerPendingOrderProtectedConfigFields = map[string]map[string]struct{}{
+ payment.TypeEasyPay: {"pkey": {}, "pid": {}},
+ payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}, "appid": {}},
+ payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}, "appid": {}, "mpappid": {}, "mchid": {}, "publickeyid": {}, "certserial": {}},
+ payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
+}
+
+func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
+ fields, ok := providerSensitiveConfigFields[providerKey]
+ if !ok {
+ return false
+ }
+ _, found := fields[strings.ToLower(fieldName)]
+ return found
+}
+
+func hasPendingOrderProtectedConfigChange(providerKey string, currentConfig, nextConfig map[string]string) bool {
+ fields, ok := providerPendingOrderProtectedConfigFields[providerKey]
+ if !ok {
+ return false
+ }
+ for fieldName := range fields {
+ if providerConfigFieldValue(currentConfig, fieldName) != providerConfigFieldValue(nextConfig, fieldName) {
return true
}
}
return false
}
+func providerConfigFieldValue(config map[string]string, fieldName string) string {
+ for key, value := range config {
+ if strings.EqualFold(key, fieldName) {
+ return value
+ }
+ }
+ return ""
+}
+
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
return s.entClient.PaymentOrder.Query().
Where(
@@ -108,6 +183,14 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err := validateProviderRequest(req.ProviderKey, req.Name, typesStr); err != nil {
return nil, err
}
+ if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil {
+ return nil, err
+ }
+ if req.Enabled {
+ if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil {
+ return nil, err
+ }
+ }
enc, err := s.encryptConfig(req.Config)
if err != nil {
return nil, err
@@ -136,18 +219,47 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
- if req.Config != nil {
- hasSensitive := false
- for k := range req.Config {
- if isSensitiveConfigField(k) && req.Config[k] != "" {
- hasSensitive = true
- break
- }
+ current, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance: %w", err)
+ }
+ var pendingOrderCount *int
+ getPendingOrderCount := func() (int, error) {
+ if pendingOrderCount != nil {
+ return *pendingOrderCount, nil
}
- if hasSensitive {
- count, err := s.countPendingOrders(ctx, id)
+ count, err := s.countPendingOrders(ctx, id)
+ if err != nil {
+ return 0, fmt.Errorf("check pending orders: %w", err)
+ }
+ pendingOrderCount = &count
+ return count, nil
+ }
+ nextEnabled := current.Enabled
+ if req.Enabled != nil {
+ nextEnabled = *req.Enabled
+ }
+ nextSupportedTypes := current.SupportedTypes
+ if req.SupportedTypes != nil {
+ nextSupportedTypes = joinTypes(req.SupportedTypes)
+ }
+ if err := s.validateVisibleMethodEnablementConflicts(ctx, id, current.ProviderKey, nextSupportedTypes, nextEnabled); err != nil {
+ return nil, err
+ }
+ var mergedConfig map[string]string
+ if req.Config != nil {
+ currentConfig, err := s.decryptConfig(current.Config)
+ if err != nil {
+ return nil, fmt.Errorf("decrypt existing config: %w", err)
+ }
+ mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
+ if err != nil {
+ return nil, err
+ }
+ if hasPendingOrderProtectedConfigChange(current.ProviderKey, currentConfig, mergedConfig) {
+ count, err := getPendingOrderCount()
if err != nil {
- return nil, fmt.Errorf("check pending orders: %w", err)
+ return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
@@ -156,25 +268,40 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
}
if req.Enabled != nil && !*req.Enabled {
- count, err := s.countPendingOrders(ctx, id)
+ count, err := getPendingOrderCount()
if err != nil {
- return nil, fmt.Errorf("check pending orders: %w", err)
+ return nil, err
}
if count > 0 {
return nil, infraerrors.Conflict("PENDING_ORDERS", "instance has pending orders").
WithMetadata(map[string]string{"count": strconv.Itoa(count)})
}
}
+ // Validate merged config when the instance will end up enabled.
+ // This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time,
+ // so admins see them in the dialog instead of only when an order is created.
+ finalEnabled := current.Enabled
+ if req.Enabled != nil {
+ finalEnabled = *req.Enabled
+ }
+ if finalEnabled {
+ configToValidate := mergedConfig
+ if configToValidate == nil {
+ configToValidate, err = s.decryptConfig(current.Config)
+ if err != nil {
+ return nil, fmt.Errorf("decrypt existing config: %w", err)
+ }
+ }
+ if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil {
+ return nil, err
+ }
+ }
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
if req.Name != nil {
u.SetName(*req.Name)
}
- if req.Config != nil {
- merged, err := s.mergeConfig(ctx, id, req.Config)
- if err != nil {
- return nil, err
- }
- enc, err := s.encryptConfig(merged)
+ if mergedConfig != nil {
+ enc, err := s.encryptConfig(mergedConfig)
if err != nil {
return nil, err
}
@@ -182,17 +309,13 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.SupportedTypes != nil {
// Check pending orders before removing payment types
- count, err := s.countPendingOrders(ctx, id)
+ count, err := getPendingOrderCount()
if err != nil {
- return nil, fmt.Errorf("check pending orders: %w", err)
+ return nil, err
}
if count > 0 {
// Load current instance to compare types
- inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
- if err != nil {
- return nil, fmt.Errorf("load provider instance: %w", err)
- }
- oldTypes := strings.Split(inst.SupportedTypes, ",")
+ oldTypes := strings.Split(current.SupportedTypes, ",")
newTypes := req.SupportedTypes
for _, ot := range oldTypes {
ot = strings.TrimSpace(ot)
@@ -237,10 +360,7 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
if req.RefundEnabled != nil {
refundEnabled = *req.RefundEnabled
} else {
- inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
- if err == nil {
- refundEnabled = inst.RefundEnabled
- }
+ refundEnabled = current.RefundEnabled
}
if refundEnabled {
u.SetAllowUserRefund(true)
@@ -282,27 +402,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
}
if existing == nil {
- return newConfig, nil
+ existing = map[string]string{}
}
for k, v := range newConfig {
+ // Preserve existing secrets when the client submits an empty value
+ // (admin UI omits the value to indicate "leave unchanged").
+ if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
+ continue
+ }
existing[k] = v
}
return existing, nil
}
-func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) {
- if encrypted == "" {
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
+// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
+// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
+// letting the admin re-enter the config via the UI to complete the migration.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
+// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
+// a few releases once all live deployments have re-saved their provider configs.
+func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
return nil, nil
}
- decrypted, err := payment.Decrypt(encrypted, s.encryptionKey)
- if err != nil {
- return nil, fmt.Errorf("decrypt config: %w", err)
+ var cfg map[string]string
+ if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
+ return cfg, nil
}
- var raw map[string]string
- if err := json.Unmarshal([]byte(decrypted), &raw); err != nil {
- return nil, fmt.Errorf("unmarshal decrypted config: %w", err)
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
+ if len(s.encryptionKey) == payment.AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
+ if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
+ return cfg, nil
+ }
+ }
}
- return raw, nil
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
}
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
@@ -317,14 +458,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
}
+// encryptConfig serialises a provider config for storage.
+// New records are written as plaintext JSON; the historical AES-GCM wrapping
+// has been dropped but decryptConfig still accepts old ciphertext during migration.
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg)
if err != nil {
return "", fmt.Errorf("marshal config: %w", err)
}
- enc, err := payment.Encrypt(string(data), s.encryptionKey)
- if err != nil {
- return "", fmt.Errorf("encrypt config: %w", err)
- }
- return enc, nil
+ return string(data), nil
}
diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go
index 2aaa874f..e0d2908a 100644
--- a/backend/internal/service/payment_config_providers_test.go
+++ b/backend/internal/service/payment_config_providers_test.go
@@ -3,8 +3,18 @@
package service
import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "strconv"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -97,41 +107,52 @@ func TestValidateProviderRequest(t *testing.T) {
}
}
-func TestIsSensitiveConfigField(t *testing.T) {
+func TestIsSensitiveProviderConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
- field string
- wantSen bool
+ providerKey string
+ field string
+ wantSen bool
}{
- // Sensitive fields (contain key/secret/private/password/pkey patterns)
- {"secretKey", true},
- {"apiSecret", true},
- {"pkey", true},
- {"privateKey", true},
- {"apiPassword", true},
- {"appKey", true},
- {"SECRET_TOKEN", true},
- {"PrivateData", true},
- {"PASSWORD", true},
- {"mySecretValue", true},
+ // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
+ {"stripe", "secretKey", true},
+ {"stripe", "webhookSecret", true},
+ {"stripe", "SecretKey", true}, // case-insensitive
+ {"stripe", "publishableKey", false},
+ {"stripe", "appId", false},
- // Non-sensitive fields
- {"appId", false},
- {"mchId", false},
- {"apiBase", false},
- {"endpoint", false},
- {"merchantNo", false},
- {"paymentMode", false},
- {"notifyUrl", false},
+ // Alipay
+ {"alipay", "privateKey", true},
+ {"alipay", "publicKey", true},
+ {"alipay", "alipayPublicKey", true},
+ {"alipay", "appId", false},
+ {"alipay", "notifyUrl", false},
+
+ // Wxpay
+ {"wxpay", "privateKey", true},
+ {"wxpay", "apiV3Key", true},
+ {"wxpay", "publicKey", true},
+ {"wxpay", "publicKeyId", false},
+ {"wxpay", "certSerial", false},
+ {"wxpay", "mchId", false},
+
+ // EasyPay
+ {"easypay", "pkey", true},
+ {"easypay", "pid", false},
+ {"easypay", "apiBase", false},
+
+ // Unknown provider: never sensitive
+ {"unknown", "secretKey", false},
}
for _, tc := range tests {
- t.Run(tc.field, func(t *testing.T) {
+ tc := tc
+ t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
t.Parallel()
- got := isSensitiveConfigField(tc.field)
- assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field)
+ got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
+ assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
})
}
}
@@ -185,3 +206,403 @@ func TestJoinTypes(t *testing.T) {
})
}
}
+
+func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ _, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "easypay",
+ Name: "EasyPay Alipay",
+ Config: map[string]string{
+ "pid": "1001",
+ "pkey": "pkey-1001",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/notify",
+ "returnUrl": "https://merchant.example.com/return",
+ },
+ SupportedTypes: []string{"alipay"},
+ Enabled: true,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "alipay",
+ Name: "Official Alipay",
+ Config: map[string]string{"appId": "app-1", "privateKey": "private-key"},
+ SupportedTypes: []string{"alipay"},
+ Enabled: true,
+ })
+ require.NoError(t, err)
+}
+
+func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ existing, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "easypay",
+ Name: "EasyPay WeChat",
+ Config: map[string]string{
+ "pid": "2001",
+ "pkey": "pkey-2001",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/notify",
+ "returnUrl": "https://merchant.example.com/return",
+ },
+ SupportedTypes: []string{"wxpay"},
+ Enabled: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, existing)
+
+ candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "wxpay",
+ Name: "Official WeChat",
+ Config: validWxpayProviderConfig(t),
+ SupportedTypes: []string{"wxpay"},
+ Enabled: false,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{
+ Enabled: boolPtrValue(true),
+ })
+ require.NoError(t, err)
+}
+
+func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: "easypay",
+ Name: "EasyPay",
+ Config: map[string]string{
+ "pid": "3001",
+ "pkey": "pkey-3001",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/notify",
+ "returnUrl": "https://merchant.example.com/return",
+ },
+ SupportedTypes: []string{"alipay"},
+ Enabled: false,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
+ Enabled: boolPtrValue(true),
+ SupportedTypes: []string{"alipay", "wxpay"},
+ })
+ require.NoError(t, err)
+
+ saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
+ require.NoError(t, err)
+ require.True(t, saved.Enabled)
+ require.Equal(t, "alipay,wxpay", saved.SupportedTypes)
+}
+
+func TestUpdateProviderInstanceRejectsProtectedConfigChangesWhilePendingOrders(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ providerKey string
+ createConfig func(*testing.T) map[string]string
+ supportedType []string
+ updateConfig map[string]string
+ fieldName string
+ wantValue string
+ }{
+ {
+ name: "wxpay appId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"appId": "wx-app-updated"},
+ fieldName: "appId",
+ wantValue: "wx-app-test",
+ },
+ {
+ name: "wxpay mpAppId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfigWithJSAPIAppID,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"mpAppId": "wx-mp-app-updated"},
+ fieldName: "mpAppId",
+ wantValue: "wx-mp-app-test",
+ },
+ {
+ name: "wxpay mchId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"mchId": "mch-updated"},
+ fieldName: "mchId",
+ wantValue: "mch-test",
+ },
+ {
+ name: "wxpay publicKeyId",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"publicKeyId": "public-key-id-updated"},
+ fieldName: "publicKeyId",
+ wantValue: "public-key-id-test",
+ },
+ {
+ name: "wxpay certSerial",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"certSerial": "cert-serial-updated"},
+ fieldName: "certSerial",
+ wantValue: "cert-serial-test",
+ },
+ {
+ name: "alipay appId",
+ providerKey: payment.TypeAlipay,
+ createConfig: validAlipayProviderConfig,
+ supportedType: []string{payment.TypeAlipay},
+ updateConfig: map[string]string{"appId": "alipay-app-updated"},
+ fieldName: "appId",
+ wantValue: "alipay-app-test",
+ },
+ {
+ name: "easypay pid",
+ providerKey: payment.TypeEasyPay,
+ createConfig: validEasyPayProviderConfig,
+ supportedType: []string{payment.TypeAlipay},
+ updateConfig: map[string]string{"pid": "pid-updated"},
+ fieldName: "pid",
+ wantValue: "pid-test",
+ },
+ }
+
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: tc.providerKey,
+ Name: "protected-config-instance",
+ Config: tc.createConfig(t),
+ SupportedTypes: tc.supportedType,
+ Enabled: true,
+ })
+ require.NoError(t, err)
+
+ createPendingProviderConfigOrder(t, ctx, client, instance)
+
+ updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
+ Config: tc.updateConfig,
+ })
+ require.Nil(t, updated)
+ require.Error(t, err)
+ require.Equal(t, "PENDING_ORDERS", infraerrors.Reason(err))
+
+ saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
+ require.NoError(t, err)
+ cfg, err := svc.decryptConfig(saved.Config)
+ require.NoError(t, err)
+ require.Equal(t, tc.wantValue, cfg[tc.fieldName])
+ })
+ }
+}
+
+func TestUpdateProviderInstanceAllowsSafeConfigChangesWhilePendingOrders(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ providerKey string
+ createConfig func(*testing.T) map[string]string
+ supportedType []string
+ updateConfig map[string]string
+ fieldName string
+ wantValue string
+ }{
+ {
+ name: "wxpay notifyUrl",
+ providerKey: payment.TypeWxpay,
+ createConfig: validWxpayProviderConfig,
+ supportedType: []string{payment.TypeWxpay},
+ updateConfig: map[string]string{"notifyUrl": "https://merchant.example.com/wxpay/notify-v2"},
+ fieldName: "notifyUrl",
+ wantValue: "https://merchant.example.com/wxpay/notify-v2",
+ },
+ {
+ name: "alipay same appId",
+ providerKey: payment.TypeAlipay,
+ createConfig: validAlipayProviderConfig,
+ supportedType: []string{payment.TypeAlipay},
+ updateConfig: map[string]string{"appId": "alipay-app-test"},
+ fieldName: "appId",
+ wantValue: "alipay-app-test",
+ },
+ }
+
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ svc := &PaymentConfigService{
+ entClient: client,
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ }
+
+ instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{
+ ProviderKey: tc.providerKey,
+ Name: "safe-config-instance",
+ Config: tc.createConfig(t),
+ SupportedTypes: tc.supportedType,
+ Enabled: true,
+ })
+ require.NoError(t, err)
+
+ createPendingProviderConfigOrder(t, ctx, client, instance)
+
+ updated, err := svc.UpdateProviderInstance(ctx, instance.ID, UpdateProviderInstanceRequest{
+ Config: tc.updateConfig,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+
+ saved, err := client.PaymentProviderInstance.Get(ctx, instance.ID)
+ require.NoError(t, err)
+ cfg, err := svc.decryptConfig(saved.Config)
+ require.NoError(t, err)
+ require.Equal(t, tc.wantValue, cfg[tc.fieldName])
+ })
+ }
+}
+
+func createPendingProviderConfigOrder(t *testing.T, ctx context.Context, client *dbent.Client, instance *dbent.PaymentProviderInstance) {
+ t.Helper()
+
+ user, err := client.User.Create().
+ SetEmail("provider-config-pending@example.com").
+ SetPasswordHash("hash").
+ SetUsername("provider-config-pending-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instanceID := strconv.FormatInt(instance.ID, 10)
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("PENDING-PROVIDER-CONFIG-" + instanceID).
+ SetOutTradeNo("sub2_pending_provider_config_" + instanceID).
+ SetPaymentType(providerPendingOrderPaymentType(instance.ProviderKey)).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instanceID).
+ SetProviderKey(instance.ProviderKey).
+ Save(ctx)
+ require.NoError(t, err)
+}
+
+func providerPendingOrderPaymentType(providerKey string) string {
+ switch providerKey {
+ case payment.TypeWxpay:
+ return payment.TypeWxpay
+ case payment.TypeAlipay:
+ return payment.TypeAlipay
+ default:
+ return payment.TypeAlipay
+ }
+}
+
+func boolPtrValue(v bool) *bool {
+ return &v
+}
+
+func validAlipayProviderConfig(t *testing.T) map[string]string {
+ t.Helper()
+
+ return map[string]string{
+ "appId": "alipay-app-test",
+ "privateKey": "alipay-private-key-test",
+ "notifyUrl": "https://merchant.example.com/alipay/notify",
+ "returnUrl": "https://merchant.example.com/alipay/return",
+ }
+}
+
+func validEasyPayProviderConfig(t *testing.T) map[string]string {
+ t.Helper()
+
+ return map[string]string{
+ "pid": "pid-test",
+ "pkey": "pkey-test",
+ "apiBase": "https://pay.example.com",
+ "notifyUrl": "https://merchant.example.com/easypay/notify",
+ "returnUrl": "https://merchant.example.com/easypay/return",
+ }
+}
+
+func validWxpayProviderConfig(t *testing.T) map[string]string {
+ t.Helper()
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ privDER, err := x509.MarshalPKCS8PrivateKey(key)
+ require.NoError(t, err)
+ pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ require.NoError(t, err)
+
+ return map[string]string{
+ "appId": "wx-app-test",
+ "mchId": "mch-test",
+ "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+ "apiV3Key": "12345678901234567890123456789012",
+ "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
+ "publicKeyId": "public-key-id-test",
+ "certSerial": "cert-serial-test",
+ }
+}
+
+func validWxpayProviderConfigWithJSAPIAppID(t *testing.T) map[string]string {
+ t.Helper()
+
+ cfg := validWxpayProviderConfig(t)
+ cfg["mpAppId"] = "wx-mp-app-test"
+ return cfg
+}
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index 59764b29..02d061ae 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct {
CancelRateLimitWindow *int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
+
+ VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
}
// MethodLimits holds per-payment-type limits.
@@ -196,6 +201,8 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
+ SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
+ SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
@@ -234,18 +241,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
+ types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
- cfg.EnabledTypes = append(cfg.EnabledTypes, t)
+ types = append(types, t)
}
}
+ cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
+ if s.entClient == nil {
+ return ""
+ }
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
@@ -282,25 +294,29 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
}
}
m := map[string]string{
- SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
- SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
- SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
- SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
- SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
- SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
- SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
- SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
- SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
- SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
- SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
- SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
- SettingHelpImageURL: derefStr(req.HelpImageURL),
- SettingHelpText: derefStr(req.HelpText),
- SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
- SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
- SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
- SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
- SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
+ SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
+ SettingMaxRechargeAmount: formatPositiveFloat(req.MaxAmount),
+ SettingDailyRechargeLimit: formatPositiveFloat(req.DailyLimit),
+ SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
+ SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
+ SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
+ SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
+ SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
+ SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
+ SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
+ SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
+ SettingHelpImageURL: derefStr(req.HelpImageURL),
+ SettingHelpText: derefStr(req.HelpText),
+ SettingCancelRateLimitOn: formatBoolOrEmpty(req.CancelRateLimitEnabled),
+ SettingCancelRateLimitMax: formatPositiveInt(req.CancelRateLimitMax),
+ SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
+ SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
+ SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource),
+ SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource),
+ SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled),
+ SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled),
}
if req.EnabledTypes != nil {
m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",")
@@ -385,3 +401,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
+
+func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
+ available := make(map[string]bool, 4)
+ for _, inst := range instances {
+ switch inst.ProviderKey {
+ case payment.TypeAlipay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
+ available[VisibleMethodSourceOfficialAlipay] = true
+ }
+ case payment.TypeWxpay:
+ if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
+ available[VisibleMethodSourceOfficialWechat] = true
+ }
+ case payment.TypeEasyPay:
+ for _, supportedType := range splitTypes(inst.SupportedTypes) {
+ switch NormalizeVisibleMethod(supportedType) {
+ case payment.TypeAlipay:
+ available[VisibleMethodSourceEasyPayAlipay] = true
+ case payment.TypeWxpay:
+ available[VisibleMethodSourceEasyPayWechat] = true
+ }
+ }
+ }
+ }
+ return available
+}
+
+func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
+ shouldExpose := map[string]bool{
+ payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
+ payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
+ }
+
+ seen := make(map[string]struct{}, len(base)+2)
+ out := make([]string, 0, len(base)+2)
+ appendType := func(paymentType string) {
+ paymentType = NormalizeVisibleMethod(paymentType)
+ if paymentType == "" {
+ return
+ }
+ if _, ok := seen[paymentType]; ok {
+ return
+ }
+ seen[paymentType] = struct{}{}
+ out = append(out, paymentType)
+ }
+
+ for _, paymentType := range base {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ switch visibleMethod {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ default:
+ appendType(visibleMethod)
+ }
+ }
+
+ for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ if shouldExpose[visibleMethod] {
+ appendType(visibleMethod)
+ }
+ }
+ return out
+}
+
+func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
+ enabledKey := visibleMethodEnabledSettingKey(method)
+ sourceKey := visibleMethodSourceSettingKey(method)
+ if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
+ return false
+ }
+ source := NormalizeVisibleMethodSource(method, vals[sourceKey])
+ return source != "" && available[source]
+}
diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go
index 027bb796..f04f4697 100644
--- a/backend/internal/service/payment_config_service_test.go
+++ b/backend/internal/service/payment_config_service_test.go
@@ -1,9 +1,19 @@
package service
import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
@@ -163,6 +173,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
+ t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
+ t.Parallel()
+ vals := map[string]string{
+ SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
+ }
+ cfg := svc.parsePaymentConfig(vals)
+ if len(cfg.EnabledTypes) != 2 {
+ t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
+ }
+ if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
+ t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
+ }
+ })
+
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
@@ -204,3 +228,210 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
+
+func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"alipay", "wxpay", "stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
+ SettingPaymentVisibleMethodWxpayEnabled: "true",
+ SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceOfficialAlipay: true,
+ VisibleMethodSourceOfficialWechat: false,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"alipay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
+ t.Parallel()
+
+ base := []string{"stripe"}
+ vals := map[string]string{
+ SettingPaymentVisibleMethodAlipayEnabled: "true",
+ SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
+ }
+ available := map[string]bool{
+ VisibleMethodSourceEasyPayAlipay: true,
+ }
+
+ got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
+ want := []string{"stripe", "alipay"}
+ if len(got) != len(want) {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
+ t.Parallel()
+
+ instances := []*dbent.PaymentProviderInstance{
+ {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
+ {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
+ {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
+ }
+
+ got := buildVisibleMethodSourceAvailability(instances)
+ if !got[VisibleMethodSourceOfficialAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
+ }
+ if !got[VisibleMethodSourceEasyPayAlipay] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
+ }
+ if !got[VisibleMethodSourceOfficialWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
+ }
+ if !got[VisibleMethodSourceEasyPayWechat] {
+ t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
+ }
+}
+
+func TestGetPaymentConfigKeepsStoredEnabledTypes(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay instance: %v", err)
+ }
+
+ svc := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
+ },
+ },
+ }
+
+ cfg, err := svc.GetPaymentConfig(ctx)
+ if err != nil {
+ t.Fatalf("GetPaymentConfig returned error: %v", err)
+ }
+
+ want := []string{payment.TypeAlipay, payment.TypeWxpay, payment.TypeStripe}
+ if len(cfg.EnabledTypes) != len(want) {
+ t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
+ }
+ for i := range want {
+ if cfg.EnabledTypes[i] != want[i] {
+ t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
+ }
+ }
+}
+
+func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ dbName := fmt.Sprintf(
+ "file:%s?mode=memory&cache=shared",
+ strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()),
+ )
+ db, err := sql.Open("sqlite", dbName)
+ if err != nil {
+ t.Fatalf("open sqlite: %v", err)
+ }
+ t.Cleanup(func() { _ = db.Close() })
+
+ if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
+ t.Fatalf("enable foreign keys: %v", err)
+ }
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+type paymentConfigSettingRepoStub struct {
+ values map[string]string
+ updates map[string]string
+}
+
+func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
+ return nil, nil
+}
+func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ return s.values[key], nil
+}
+func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
+func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ out[key] = s.values[key]
+ }
+ return out, nil
+}
+func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error {
+ s.updates = make(map[string]string, len(values))
+ for key, value := range values {
+ s.updates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ return s.values, nil
+}
+func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
+
+func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) {
+ repo := &paymentConfigSettingRepoStub{values: map[string]string{}}
+ svc := &PaymentConfigService{settingRepo: repo}
+
+ alipayEnabled := true
+ wxpayEnabled := false
+ err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{
+ VisibleMethodAlipayEnabled: &alipayEnabled,
+ VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay),
+ VisibleMethodWxpayEnabled: &wxpayEnabled,
+ VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat),
+ })
+ if err != nil {
+ t.Fatalf("UpdatePaymentConfig returned error: %v", err)
+ }
+
+ if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" {
+ t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled])
+ }
+ if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay {
+ t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay)
+ }
+ if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" {
+ t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled])
+ }
+ if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat {
+ t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat)
+ }
+}
+
+func paymentConfigStrPtr(value string) *string {
+ return &value
+}
diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go
index 44818b37..5df69aea 100644
--- a/backend/internal/service/payment_fulfillment.go
+++ b/backend/internal/service/payment_fulfillment.go
@@ -2,6 +2,8 @@ package service
import (
"context"
+ "encoding/json"
+ "errors"
"fmt"
"log/slog"
"math"
@@ -16,6 +18,14 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
+// ErrOrderNotFound is returned by HandlePaymentNotification when the webhook
+// references an out_trade_no that does not exist in our DB. Callers (webhook
+// handlers) should treat this as a terminal, non-retryable condition and still
+// respond with a 2xx success to the provider — otherwise the provider will keep
+// retrying forever (e.g. when a foreign environment's webhook endpoint is
+// misconfigured to point at us, or when our orders table has been wiped).
+var ErrOrderNotFound = errors.New("payment order not found")
+
// --- Payment Notification & Fulfillment ---
func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payment.PaymentNotification, pk string) error {
@@ -25,37 +35,102 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
- // Fallback: try legacy format (sub2_N where N is DB ID)
- trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
- if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
- return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
+ // Fallback only for true legacy "sub2_N" DB-ID payloads when the
+ // current out_trade_no lookup genuinely did not find an order.
+ if oid, ok := parseLegacyPaymentOrderID(n.OrderID, err); ok {
+ return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk, n.Metadata)
}
- return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
+ if dbent.IsNotFound(err) {
+ return fmt.Errorf("%w: out_trade_no=%s", ErrOrderNotFound, n.OrderID)
+ }
+ return fmt.Errorf("lookup order failed for out_trade_no %s: %w", n.OrderID, err)
}
- return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
+ return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk, n.Metadata)
}
-func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
+func parseLegacyPaymentOrderID(orderID string, lookupErr error) (int64, bool) {
+ if !dbent.IsNotFound(lookupErr) {
+ return 0, false
+ }
+ orderID = strings.TrimSpace(orderID)
+ if !strings.HasPrefix(orderID, orderIDPrefix) {
+ return 0, false
+ }
+ trimmed := strings.TrimPrefix(orderID, orderIDPrefix)
+ if trimmed == "" || trimmed == orderID {
+ return 0, false
+ }
+ oid, err := strconv.ParseInt(trimmed, 10, 64)
+ if err != nil || oid <= 0 {
+ return 0, false
+ }
+ return oid, true
+}
+
+func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string, metadata map[string]string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid)
if err != nil {
slog.Error("order not found", "orderID", oid)
return nil
}
- // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
- // Also skip if paid is NaN/Inf (malformed provider data).
- if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
- if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
- s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
- return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
- }
+ instanceProviderKey := ""
+ if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil {
+ instanceProviderKey = inst.ProviderKey
}
- // Use order's expected amount when provider didn't report one
- if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
- paid = o.PayAmount
+ expectedProviderKey := expectedNotificationProviderKeyForOrder(s.registry, o, instanceProviderKey)
+ if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{
+ "expectedProvider": expectedProviderKey,
+ "actualProvider": pk,
+ "tradeNo": tradeNo,
+ })
+ return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk)
+ }
+ if err := validateProviderNotificationMetadata(o, pk, metadata); err != nil {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_METADATA_MISMATCH", pk, map[string]any{
+ "detail": err.Error(),
+ "tradeNo": tradeNo,
+ })
+ return err
+ }
+ if !isValidProviderAmount(paid) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{
+ "expected": o.PayAmount,
+ "paid": paid,
+ "tradeNo": tradeNo,
+ })
+ return fmt.Errorf("invalid paid amount from provider: %v", paid)
+ }
+ if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
+ return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
+func isValidProviderAmount(amount float64) bool {
+ return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
+}
+
+func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ return validateProviderSnapshotMetadata(order, providerKey, metadata)
+}
+
+func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string {
+ if key := strings.TrimSpace(instanceProviderKey); key != "" {
+ return key
+ }
+ if key := strings.TrimSpace(orderProviderKey); key != "" {
+ return key
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(orderPaymentType)
+}
+
func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error {
previousStatus := o.Status
now := time.Now()
@@ -194,6 +269,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
switch action {
case redeemActionSkipCompleted:
+ if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
+ return err
+ }
// Code already created and redeemed — just mark completed
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
case redeemActionCreate:
@@ -207,6 +285,9 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e
if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil {
return fmt.Errorf("redeem balance: %w", err)
}
+ if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil {
+ return err
+ }
return s.markCompleted(ctx, o, "RECHARGE_SUCCESS")
}
@@ -284,6 +365,142 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action
return c > 0
}
+func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error {
+ if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 {
+ return nil
+ }
+ if s.affiliateService == nil {
+ return nil
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": fmt.Sprintf("begin affiliate rebate tx: %v", err),
+ })
+ return fmt.Errorf("begin affiliate rebate tx: %w", err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount)
+ if err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return fmt.Errorf("claim affiliate rebate audit: %w", err)
+ }
+ if !claimed {
+ return nil
+ }
+
+ rebateAmount, err := s.affiliateService.AccrueInviteRebate(txCtx, o.UserID, o.Amount)
+ if err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return fmt.Errorf("accrue affiliate rebate: %w", err)
+ }
+
+ if rebateAmount <= 0 {
+ if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{
+ "baseAmount": o.Amount,
+ "reason": "no inviter bound or rebate amount <= 0",
+ }); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return fmt.Errorf("update affiliate rebate skipped audit: %w", err)
+ }
+ if err := tx.Commit(); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
+ })
+ return fmt.Errorf("commit affiliate rebate tx: %w", err)
+ }
+ return nil
+ }
+
+ if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{
+ "baseAmount": o.Amount,
+ "rebateAmount": rebateAmount,
+ }); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": err.Error(),
+ })
+ return fmt.Errorf("update affiliate rebate applied audit: %w", err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{
+ "error": fmt.Sprintf("commit affiliate rebate tx: %v", err),
+ })
+ return fmt.Errorf("commit affiliate rebate tx: %w", err)
+ }
+ return nil
+}
+
+func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) {
+ if client == nil {
+ return false, errors.New("nil payment client")
+ }
+ oid := strconv.FormatInt(orderID, 10)
+ detail, _ := json.Marshal(map[string]any{
+ "baseAmount": baseAmount,
+ "status": "reserved",
+ })
+ rows, err := client.QueryContext(ctx, `
+INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
+SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW()
+WHERE NOT EXISTS (
+ SELECT 1
+ FROM payment_audit_logs
+ WHERE order_id = $1::text
+ AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
+)
+ON CONFLICT (order_id, action) DO NOTHING
+RETURNING id`, oid, string(detail))
+ if err != nil {
+ return false, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return false, err
+ }
+ return false, nil
+ }
+ var claimID int64
+ if err := rows.Scan(&claimID); err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
+func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error {
+ if client == nil {
+ return errors.New("nil payment client")
+ }
+ oid := strconv.FormatInt(orderID, 10)
+ detailJSON, _ := json.Marshal(detail)
+ updated, err := client.PaymentAuditLog.Update().
+ Where(
+ paymentauditlog.OrderIDEQ(oid),
+ paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED"),
+ ).
+ SetAction(action).
+ SetDetail(string(detailJSON)).
+ SetOperator("system").
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if updated == 0 {
+ return errors.New("affiliate rebate claim log not found")
+ }
+ return nil
+}
+
func (s *PaymentService) markFailed(ctx context.Context, oid int64, cause error) {
now := time.Now()
r := psErrMsg(cause)
diff --git a/backend/internal/service/payment_fulfillment_order_not_found_test.go b/backend/internal/service/payment_fulfillment_order_not_found_test.go
new file mode 100644
index 00000000..f6787e29
--- /dev/null
+++ b/backend/internal/service/payment_fulfillment_order_not_found_test.go
@@ -0,0 +1,106 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "testing"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+// newOrderNotFoundTestClient wires an in-memory sqlite-backed ent.Client so
+// tests can exercise HandlePaymentNotification's real DB lookup path without
+// standing up a service stack.
+func newOrderNotFoundTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_order_not_found?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+// TestHandlePaymentNotification_UnknownOrder_ReturnsSentinel exercises the
+// happy-path of the webhook 404 fix: when the notification references an
+// out_trade_no that does not exist in our DB, HandlePaymentNotification must
+// return an error that errors.Is(err, ErrOrderNotFound) recognizes. The
+// webhook handler relies on that contract to ack with a 2xx so the provider
+// stops retrying.
+func TestHandlePaymentNotification_UnknownOrder_ReturnsSentinel(t *testing.T) {
+ ctx := context.Background()
+ client := newOrderNotFoundTestClient(t)
+
+ svc := &PaymentService{
+ entClient: client,
+ providersLoaded: true,
+ }
+
+ notification := &payment.PaymentNotification{
+ OrderID: "sub2_does_not_exist_12345",
+ TradeNo: "stripe_evt_test_xyz",
+ Status: payment.NotificationStatusSuccess,
+ Amount: 1000,
+ }
+
+ err := svc.HandlePaymentNotification(ctx, notification, payment.TypeStripe)
+ require.Error(t, err, "unknown out_trade_no should surface an error")
+ require.ErrorIs(t, err, ErrOrderNotFound,
+ "webhook handler relies on errors.Is(err, ErrOrderNotFound) to downgrade to 200")
+
+ // Sanity: the wrapped error message should still include the out_trade_no
+ // for operator diagnostics.
+ require.Contains(t, err.Error(), notification.OrderID)
+}
+
+// TestHandlePaymentNotification_NonSuccessStatus_Skips documents the
+// short-circuit that precedes the DB lookup: when the notification is not a
+// success event (e.g. Stripe non-payment events that reach us via the webhook
+// route), we return nil without touching the DB and the handler responds 200.
+func TestHandlePaymentNotification_NonSuccessStatus_Skips(t *testing.T) {
+ ctx := context.Background()
+ client := newOrderNotFoundTestClient(t)
+
+ svc := &PaymentService{
+ entClient: client,
+ providersLoaded: true,
+ }
+
+ notification := &payment.PaymentNotification{
+ OrderID: "sub2_does_not_exist_12345",
+ Status: "failed", // any value other than NotificationStatusSuccess
+ }
+
+ err := svc.HandlePaymentNotification(ctx, notification, payment.TypeStripe)
+ require.NoError(t, err,
+ "non-success notifications must short-circuit before the DB lookup")
+}
+
+// TestErrOrderNotFound_DistinctFromOtherErrors guards against an accidental
+// collapse where a generic wrapped error would start matching ErrOrderNotFound
+// (which would silently mask real DB failures).
+func TestErrOrderNotFound_DistinctFromOtherErrors(t *testing.T) {
+ genericErr := errors.New("some other failure")
+ require.False(t, errors.Is(genericErr, ErrOrderNotFound))
+ require.False(t, errors.Is(ErrOrderNotFound, genericErr))
+
+ wrappedLookupErr := errors.New("lookup order failed for out_trade_no sub2_42: connection refused")
+ require.False(t, errors.Is(wrappedLookupErr, ErrOrderNotFound),
+ "DB connection failures must not masquerade as order-not-found")
+}
diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go
index 625b0d9f..abdb59de 100644
--- a/backend/internal/service/payment_fulfillment_test.go
+++ b/backend/internal/service/payment_fulfillment_test.go
@@ -3,12 +3,39 @@
package service
import (
+ "context"
"errors"
+ "math"
"testing"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/stretchr/testify/assert"
)
+type paymentFulfillmentTestProvider struct {
+ key string
+ supportedTypes []payment.PaymentType
+}
+
+func (p paymentFulfillmentTestProvider) Name() string { return p.key }
+func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key }
+func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType {
+ return p.supportedTypes
+}
+func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
// ---------------------------------------------------------------------------
// resolveRedeemAction — pure idempotency decision logic
// ---------------------------------------------------------------------------
@@ -161,3 +188,181 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) {
assert.True(t, unusedCode.CanUse())
assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil))
}
+
+func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay),
+ )
+}
+
+func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeEasyPay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t,
+ payment.TypeWxpay,
+ expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""),
+ )
+}
+
+func TestExpectedNotificationProviderKeyForOrderUsesSnapshotProviderKey(t *testing.T) {
+ t.Parallel()
+
+ registry := payment.NewRegistry()
+ registry.Register(paymentFulfillmentTestProvider{
+ key: payment.TypeAlipay,
+ supportedTypes: []payment.PaymentType{payment.TypeAlipay},
+ })
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_key": payment.TypeEasyPay,
+ },
+ }
+
+ assert.Equal(t,
+ payment.TypeEasyPay,
+ expectedNotificationProviderKeyForOrder(registry, order, ""),
+ )
+}
+
+func TestValidateProviderNotificationMetadataRejectsWxpaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "merchant_app_id": "wx-app-expected",
+ "merchant_id": "mch-expected",
+ "currency": "CNY",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-other",
+ "mchid": "mch-expected",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.ErrorContains(t, err, "wxpay appid mismatch")
+}
+
+func TestValidateProviderNotificationMetadataAllowsLegacyOrdersWithoutSnapshotFields(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "9",
+ "provider_key": payment.TypeWxpay,
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeWxpay, map[string]string{
+ "appid": "wx-app-runtime",
+ "mchid": "mch-runtime",
+ "currency": "CNY",
+ "trade_state": "SUCCESS",
+ })
+ assert.NoError(t, err)
+}
+
+func TestParseLegacyPaymentOrderID(t *testing.T) {
+ t.Parallel()
+
+ oid, ok := parseLegacyPaymentOrderID("sub2_42", &dbent.NotFoundError{})
+ assert.True(t, ok)
+ assert.EqualValues(t, 42, oid)
+
+ _, ok = parseLegacyPaymentOrderID("42", &dbent.NotFoundError{})
+ assert.False(t, ok)
+
+ _, ok = parseLegacyPaymentOrderID("sub2_42", errors.New("db down"))
+ assert.False(t, ok)
+}
+
+func TestIsValidProviderAmount(t *testing.T) {
+ t.Parallel()
+
+ assert.True(t, isValidProviderAmount(0.01))
+ assert.False(t, isValidProviderAmount(0))
+ assert.False(t, isValidProviderAmount(-1))
+ assert.False(t, isValidProviderAmount(math.NaN()))
+ assert.False(t, isValidProviderAmount(math.Inf(1)))
+}
+
+func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_app_id": "alipay-app-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeAlipay, map[string]string{
+ "app_id": "alipay-app-other",
+ })
+ assert.ErrorContains(t, err, "alipay app_id mismatch")
+}
+
+func TestValidateProviderNotificationMetadataRejectsEasyPaySnapshotMismatch(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "merchant_id": "pid-expected",
+ },
+ }
+
+ err := validateProviderNotificationMetadata(order, payment.TypeEasyPay, map[string]string{
+ "pid": "pid-other",
+ })
+ assert.ErrorContains(t, err, "easypay pid mismatch")
+}
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 128416e4..15d4509d 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -2,9 +2,11 @@ package service
import (
"context"
+ "errors"
"fmt"
"log/slog"
"math"
+ "net/url"
"strconv"
"strings"
"time"
@@ -22,6 +24,9 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if req.OrderType == "" {
req.OrderType = payment.OrderTypeBalance
}
+ if normalized := NormalizeVisibleMethod(req.PaymentType); normalized != "" {
+ req.PaymentType = normalized
+ }
cfg, err := s.configService.GetPaymentConfig(ctx)
if err != nil {
return nil, fmt.Errorf("get payment config: %w", err)
@@ -54,11 +59,25 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
feeRate := cfg.RechargeFeeRate
payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
- order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount)
+ sel, err := s.selectCreateOrderInstance(ctx, req, cfg, payAmount)
if err != nil {
return nil, err
}
- resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan)
+ if err := s.validateSelectedCreateOrderInstance(ctx, req, sel); err != nil {
+ return nil, err
+ }
+ oauthResp, err := s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, limitAmount, payAmount, feeRate, sel)
+ if err != nil {
+ return nil, err
+ }
+ if oauthResp != nil {
+ return oauthResp, nil
+ }
+ order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount, sel)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan, sel)
if err != nil {
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetStatus(OrderStatusFailed).
@@ -103,7 +122,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return plan, nil
}
-func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
+func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64, sel *payment.InstanceSelection) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err)
@@ -120,6 +139,17 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
+ outTradeNo, err := s.allocateOutTradeNo(ctx, tx)
+ if err != nil {
+ return nil, err
+ }
+ providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
+ selectedInstanceID := ""
+ selectedProviderKey := ""
+ if sel != nil {
+ selectedInstanceID = strings.TrimSpace(sel.InstanceID)
+ selectedProviderKey = strings.TrimSpace(sel.ProviderKey)
+ }
b := tx.PaymentOrder.Create().
SetUserID(req.UserID).
SetUserEmail(user.Email).
@@ -129,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
- SetOutTradeNo(generateOutTradeNo()).
+ SetOutTradeNo(outTradeNo).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
@@ -140,6 +170,15 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if req.SrcURL != "" {
b.SetSrcURL(req.SrcURL)
}
+ if selectedInstanceID != "" {
+ b.SetProviderInstanceID(selectedInstanceID)
+ }
+ if selectedProviderKey != "" {
+ b.SetProviderKey(selectedProviderKey)
+ }
+ if providerSnapshot != nil {
+ b.SetProviderSnapshot(providerSnapshot)
+ }
if plan != nil {
b.SetPlanID(plan.ID).SetSubscriptionGroupID(plan.GroupID).SetSubscriptionDays(psComputeValidityDays(plan.ValidityDays, plan.ValidityUnit))
}
@@ -158,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return order, nil
}
+func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) {
+ const maxAttempts = 5
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ candidate := generateOutTradeNo()
+ exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx)
+ if err != nil {
+ return "", fmt.Errorf("check out_trade_no uniqueness: %w", err)
+ }
+ if !exists {
+ return candidate, nil
+ }
+ }
+ return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts)
+}
+
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
@@ -167,12 +221,71 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return fmt.Errorf("count pending orders: %w", err)
}
if c >= max {
- return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)).
+ return infraerrors.TooManyRequests("TOO_MANY_PENDING", "too_many_pending").
WithMetadata(map[string]string{"max": strconv.Itoa(max)})
}
return nil
}
+func buildPaymentOrderProviderSnapshot(sel *payment.InstanceSelection, req CreateOrderRequest) map[string]any {
+ if sel == nil {
+ return nil
+ }
+
+ snapshot := map[string]any{}
+ snapshot["schema_version"] = 2
+
+ instanceID := strings.TrimSpace(sel.InstanceID)
+ if instanceID != "" {
+ snapshot["provider_instance_id"] = instanceID
+ }
+
+ providerKey := strings.TrimSpace(sel.ProviderKey)
+ if providerKey != "" {
+ snapshot["provider_key"] = providerKey
+ }
+
+ paymentMode := strings.TrimSpace(sel.PaymentMode)
+ if paymentMode != "" {
+ snapshot["payment_mode"] = paymentMode
+ }
+
+ if providerKey == payment.TypeWxpay {
+ if merchantAppID := paymentOrderSnapshotWxpayAppID(sel, req); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ if merchantID := strings.TrimSpace(sel.Config["mchId"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ snapshot["currency"] = "CNY"
+ }
+ if providerKey == payment.TypeAlipay {
+ if merchantAppID := strings.TrimSpace(sel.Config["appId"]); merchantAppID != "" {
+ snapshot["merchant_app_id"] = merchantAppID
+ }
+ }
+ if providerKey == payment.TypeEasyPay {
+ if merchantID := strings.TrimSpace(sel.Config["pid"]); merchantID != "" {
+ snapshot["merchant_id"] = merchantID
+ }
+ }
+
+ if len(snapshot) == 1 {
+ return nil
+ }
+ return snapshot
+}
+
+func paymentOrderSnapshotWxpayAppID(sel *payment.InstanceSelection, req CreateOrderRequest) string {
+ if sel == nil || strings.TrimSpace(sel.ProviderKey) != payment.TypeWxpay {
+ return ""
+ }
+ if strings.TrimSpace(req.OpenID) != "" {
+ return strings.TrimSpace(provider.ResolveWxpayJSAPIAppID(sel.Config))
+ }
+ return strings.TrimSpace(sel.Config["appId"])
+}
+
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
@@ -191,33 +304,127 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
used += o.Amount
}
if used+amount > limit {
- return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used)))
+ return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily_limit_exceeded").
+ WithMetadata(map[string]string{"remaining": fmt.Sprintf("%.2f", math.Max(0, limit-used))})
}
return nil
}
-func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
- // Select an instance across all providers that support the requested payment type.
- // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
- sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
+func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req CreateOrderRequest, cfg *PaymentConfig, payAmount float64) (*payment.InstanceSelection, error) {
+ selectCtx, err := s.prepareCreateOrderSelectionContext(ctx, req)
if err != nil {
- return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
+ return nil, err
+ }
+ sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
+ if err != nil {
+ return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured").
+ WithMetadata(map[string]string{"payment_type": req.PaymentType})
}
if sel == nil {
- return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
+ return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance")
}
+ return sel, nil
+}
+
+func (s *PaymentService) prepareCreateOrderSelectionContext(ctx context.Context, req CreateOrderRequest) (context.Context, error) {
+ if !requestNeedsWeChatJSAPICompatibility(req) {
+ return ctx, nil
+ }
+ if !s.usesOfficialWxpayVisibleMethod(ctx) {
+ return ctx, nil
+ }
+ expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return payment.WithWxpayJSAPIAppID(ctx, expectedAppID), nil
+}
+
+func requestNeedsWeChatJSAPICompatibility(req CreateOrderRequest) bool {
+ if payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return false
+ }
+ return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
+}
+
+func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) bool {
+ if s == nil || s.configService == nil {
+ return false
+ }
+ inst, err := s.configService.resolveEnabledVisibleMethodInstance(ctx, payment.TypeWxpay)
+ if err != nil {
+ return false
+ }
+ if inst == nil {
+ return false
+ }
+ return inst.ProviderKey == payment.TypeWxpay
+}
+
+func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
- return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
+ slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
+ // If the provider returned a structured ApplicationError (e.g. WXPAY_CONFIG_MISSING_KEY),
+ // pass it through with provider context added to metadata. Otherwise wrap as PAYMENT_PROVIDER_MISCONFIGURED.
+ if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
+ md := map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID}
+ for k, v := range appErr.Metadata {
+ md[k] = v
+ }
+ return nil, appErr.WithMetadata(md)
+ }
+ return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
+ WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
}
subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo
- pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
+ canonicalReturnURL, err := CanonicalizeReturnURL(req.ReturnURL, req.SrcHost, req.SrcURL)
+ if err != nil {
+ return nil, err
+ }
+ resumeToken := ""
+ if resume := s.paymentResume(); resume != nil {
+ if canonicalReturnURL != "" && resume.isSigningConfigured() {
+ resumeToken, err = resume.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: order.UserID,
+ ProviderInstanceID: sel.InstanceID,
+ ProviderKey: sel.ProviderKey,
+ PaymentType: req.PaymentType,
+ CanonicalReturnURL: canonicalReturnURL,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("create payment resume token: %w", err)
+ }
+ }
+ }
+ providerReturnURL, err := buildPaymentReturnURL(canonicalReturnURL, order.ID, outTradeNo, resumeToken)
+ if err != nil {
+ return nil, err
+ }
+ providerReq := buildProviderCreatePaymentRequest(CreateOrderRequest{
+ PaymentType: req.PaymentType,
+ OpenID: req.OpenID,
+ ClientIP: req.ClientIP,
+ IsMobile: req.IsMobile,
+ ReturnURL: providerReturnURL,
+ }, sel, outTradeNo, payAmountStr, subject)
+ pr, err := prov.CreatePayment(ctx, providerReq)
if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
- return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
+ if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
+ return nil, appErr
+ }
+ return nil, classifyCreatePaymentError(req, sel.ProviderKey, err)
}
- _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
+ _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).
+ SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).
+ SetNillablePayURL(psNilIfEmpty(pr.PayURL)).
+ SetNillableQrCode(psNilIfEmpty(pr.QRCode)).
+ SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).
+ SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)).
+ Save(ctx)
if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err)
}
@@ -227,8 +434,36 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
"payAmount": order.PayAmount,
"paymentType": req.PaymentType,
"orderType": req.OrderType,
+ "paymentSource": NormalizePaymentSource(req.PaymentSource),
})
- return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
+ resultType := pr.ResultType
+ if resultType == "" {
+ resultType = payment.CreatePaymentResultOrderCreated
+ }
+ resp := buildCreateOrderResponse(order, req, payAmount, sel, pr, resultType)
+ resp.ResumeToken = resumeToken
+ return resp, nil
+}
+
+func buildProviderCreatePaymentRequest(req CreateOrderRequest, sel *payment.InstanceSelection, orderID, amount, subject string) payment.CreatePaymentRequest {
+ return payment.CreatePaymentRequest{
+ OrderID: orderID,
+ Amount: amount,
+ PaymentType: req.PaymentType,
+ Subject: subject,
+ ReturnURL: req.ReturnURL,
+ OpenID: strings.TrimSpace(req.OpenID),
+ ClientIP: req.ClientIP,
+ IsMobile: req.IsMobile,
+ InstanceSubMethods: selectedInstanceSupportedTypes(sel),
+ }
+}
+
+func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
+ if sel == nil {
+ return ""
+ }
+ return sel.SupportedTypes
}
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
@@ -247,6 +482,193 @@ func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limit
return "Sub2API " + amountStr + " CNY"
}
+func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
+ return s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, amount, payAmount, feeRate, nil)
+}
+
+func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponseForSelection(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
+ if sel != nil && sel.ProviderKey != "" && sel.ProviderKey != payment.TypeWxpay {
+ return nil, nil
+ }
+ if strings.TrimSpace(req.OpenID) != "" || !req.IsWeChatBrowser || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return nil, nil
+ }
+ return s.buildWeChatOAuthRequiredResponse(ctx, req, amount, payAmount, feeRate)
+}
+
+func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
+ appID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := s.paymentResume().ensureSigningKey(); err != nil {
+ return nil, err
+ }
+
+ authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
+ if err != nil {
+ return nil, err
+ }
+
+ return &CreateOrderResponse{
+ Amount: amount,
+ PayAmount: payAmount,
+ FeeRate: feeRate,
+ ResultType: payment.CreatePaymentResultOAuthRequired,
+ PaymentType: req.PaymentType,
+ OAuth: &payment.WechatOAuthInfo{
+ AuthorizeURL: authorizeURL,
+ AppID: appID,
+ Scope: "snsapi_base",
+ RedirectURL: "/auth/wechat/payment/callback",
+ },
+ }, nil
+}
+
+func (s *PaymentService) validateSelectedCreateOrderInstance(ctx context.Context, req CreateOrderRequest, sel *payment.InstanceSelection) error {
+ if !requiresWeChatJSAPICompatibleSelection(req, sel) {
+ return nil
+ }
+ expectedAppID, _, err := s.getWeChatPaymentOAuthCredential(ctx)
+ if err != nil {
+ return err
+ }
+ selectedAppID := provider.ResolveWxpayJSAPIAppID(sel.Config)
+ if selectedAppID == "" || selectedAppID != expectedAppID {
+ return infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "selected payment instance is not compatible with the current WeChat OAuth app")
+ }
+ return nil
+}
+
+func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment.InstanceSelection) bool {
+ if sel == nil || sel.ProviderKey != payment.TypeWxpay || payment.GetBasePaymentType(req.PaymentType) != payment.TypeWxpay {
+ return false
+ }
+ return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != ""
+}
+
+func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) {
+ if s == nil || s.configService == nil || s.configService.settingRepo == nil {
+ return "", "", infraerrors.ServiceUnavailable(
+ "WECHAT_PAYMENT_MP_NOT_CONFIGURED",
+ "wechat in-app payment requires a complete WeChat MP OAuth credential",
+ )
+ }
+ cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx)
+ appID := strings.TrimSpace(cfg.AppIDForMode("mp"))
+ appSecret := strings.TrimSpace(cfg.AppSecretForMode("mp"))
+ if err != nil || !cfg.SupportsMode("mp") || appID == "" || appSecret == "" {
+ return "", "", infraerrors.ServiceUnavailable(
+ "WECHAT_PAYMENT_MP_NOT_CONFIGURED",
+ "wechat in-app payment requires a complete WeChat MP OAuth credential",
+ )
+ }
+ return appID, appSecret, nil
+}
+
+func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error {
+ if err == nil {
+ return nil
+ }
+ if providerKey == payment.TypeWxpay &&
+ payment.GetBasePaymentType(req.PaymentType) == payment.TypeWxpay &&
+ strings.Contains(err.Error(), "wxpay h5 payments are not authorized for this merchant") {
+ return infraerrors.ServiceUnavailable(
+ "WECHAT_H5_NOT_AUTHORIZED",
+ "wechat h5 payment is not available for this merchant",
+ ).WithMetadata(map[string]string{
+ "action": "open_in_wechat_or_scan_qr",
+ })
+ }
+ return infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
+}
+
+func buildCreateOrderResponse(order *dbent.PaymentOrder, req CreateOrderRequest, payAmount float64, sel *payment.InstanceSelection, pr *payment.CreatePaymentResponse, resultType payment.CreatePaymentResultType) *CreateOrderResponse {
+ return &CreateOrderResponse{
+ OrderID: order.ID,
+ Amount: order.Amount,
+ PayAmount: payAmount,
+ FeeRate: order.FeeRate,
+ Status: OrderStatusPending,
+ ResultType: resultType,
+ PaymentType: req.PaymentType,
+ OutTradeNo: order.OutTradeNo,
+ PayURL: pr.PayURL,
+ QRCode: pr.QRCode,
+ ClientSecret: pr.ClientSecret,
+ OAuth: pr.OAuth,
+ JSAPI: pr.JSAPI,
+ JSAPIPayload: pr.JSAPI,
+ ExpiresAt: order.ExpiresAt,
+ PaymentMode: sel.PaymentMode,
+ }
+}
+
+func buildWeChatPaymentOAuthStartURL(req CreateOrderRequest, scope string) (string, error) {
+ u, err := url.Parse("/api/v1/auth/oauth/wechat/payment/start")
+ if err != nil {
+ return "", fmt.Errorf("build wechat payment oauth start url: %w", err)
+ }
+ q := u.Query()
+ q.Set("payment_type", strings.TrimSpace(req.PaymentType))
+ if req.Amount > 0 {
+ q.Set("amount", strconv.FormatFloat(req.Amount, 'f', -1, 64))
+ }
+ if orderType := strings.TrimSpace(req.OrderType); orderType != "" {
+ q.Set("order_type", orderType)
+ }
+ if req.PlanID > 0 {
+ q.Set("plan_id", strconv.FormatInt(req.PlanID, 10))
+ }
+ if scope = strings.TrimSpace(scope); scope != "" {
+ q.Set("scope", scope)
+ }
+ if redirectTo := paymentRedirectPathFromURL(req.SrcURL); redirectTo != "" {
+ q.Set("redirect", redirectTo)
+ }
+ u.RawQuery = q.Encode()
+ return u.String(), nil
+}
+
+func paymentRedirectPathFromURL(rawURL string) string {
+ rawURL = strings.TrimSpace(rawURL)
+ if rawURL == "" {
+ return "/purchase"
+ }
+ if strings.HasPrefix(rawURL, "/") && !strings.HasPrefix(rawURL, "//") {
+ return normalizePaymentRedirectPath(rawURL)
+ }
+ u, err := url.Parse(rawURL)
+ if err != nil {
+ return "/purchase"
+ }
+ path := strings.TrimSpace(u.EscapedPath())
+ if path == "" {
+ path = strings.TrimSpace(u.Path)
+ }
+ if path == "" || !strings.HasPrefix(path, "/") || strings.HasPrefix(path, "//") {
+ return "/purchase"
+ }
+ if strings.TrimSpace(u.RawQuery) != "" {
+ path += "?" + u.RawQuery
+ }
+ return normalizePaymentRedirectPath(path)
+}
+
+func normalizePaymentRedirectPath(path string) string {
+ path = strings.TrimSpace(path)
+ if path == "" {
+ return "/purchase"
+ }
+ if path == "/payment" {
+ return "/purchase"
+ }
+ if strings.HasPrefix(path, "/payment?") {
+ return "/purchase" + strings.TrimPrefix(path, "/payment")
+ }
+ return path
+}
+
// --- Order Queries ---
func (s *PaymentService) GetOrder(ctx context.Context, orderID, userID int64) (*dbent.PaymentOrder, error) {
diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go
new file mode 100644
index 00000000..8c5e4fc0
--- /dev/null
+++ b/backend/internal/service/payment_order_jsapi_test.go
@@ -0,0 +1,98 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+func TestUsesOfficialWxpayVisibleMethodDerivesFromEnabledProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{entClient: client},
+ }
+
+ if !svc.usesOfficialWxpayVisibleMethod(ctx) {
+ t.Fatal("expected official wxpay visible method to be detected from enabled provider instance")
+ }
+}
+
+func TestUsesOfficialWxpayVisibleMethodRespectsConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
+ tests := []struct {
+ name string
+ source string
+ wantOfficial bool
+ }{
+ {
+ name: "official source selected",
+ source: VisibleMethodSourceOfficialWechat,
+ wantOfficial: true,
+ },
+ {
+ name: "easypay source selected",
+ source: VisibleMethodSourceEasyPayWechat,
+ wantOfficial: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("Official WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official wxpay instance: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay WeChat").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay wxpay instance: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ SettingPaymentVisibleMethodWxpaySource: tt.source,
+ },
+ },
+ },
+ }
+
+ if got := svc.usesOfficialWxpayVisibleMethod(ctx); got != tt.wantOfficial {
+ t.Fatalf("usesOfficialWxpayVisibleMethod() = %v, want %v", got, tt.wantOfficial)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index 80147180..b627ced4 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -139,34 +140,123 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
if err != nil {
return ""
}
- // Use OutTradeNo as fallback when PaymentTradeNo is empty
- // (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
- tradeNo := o.PaymentTradeNo
- if tradeNo == "" {
- tradeNo = o.OutTradeNo
+ queryRef := paymentOrderQueryReference(o, prov)
+ if queryRef == "" {
+ return ""
}
- resp, err := prov.QueryOrder(ctx, tradeNo)
+ resp, err := prov.QueryOrder(ctx, queryRef)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
- if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
+ if !isValidProviderAmount(resp.Amount) {
+ s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{
+ "expected": o.PayAmount,
+ "paid": resp.Amount,
+ "tradeNo": resp.TradeNo,
+ "queryRef": queryRef,
+ })
+ slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount)
+ retriedResp, retryOK := requeryPaidOrderOnce(ctx, prov, queryRef)
+ if !retryOK {
+ return ""
+ }
+ resp = retriedResp
+ }
+ notificationTradeNo := o.PaymentTradeNo
+ if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
+ if _, updateErr := s.entClient.PaymentOrder.Update().
+ Where(paymentorder.IDEQ(o.ID)).
+ SetPaymentTradeNo(upstreamTradeNo).
+ Save(ctx); updateErr != nil {
+ slog.Error("persist upstream trade no during checkPaid failed", "orderID", o.ID, "tradeNo", upstreamTradeNo, "error", updateErr)
+ } else {
+ o.PaymentTradeNo = upstreamTradeNo
+ }
+ notificationTradeNo = upstreamTradeNo
+ }
+ if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: notificationTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess, Metadata: resp.Metadata}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return checkPaidResultAlreadyPaid
}
if cp, ok := prov.(payment.CancelableProvider); ok {
- _ = cp.CancelPayment(ctx, tradeNo)
+ _ = cp.CancelPayment(ctx, queryRef)
}
return ""
}
+func requeryPaidOrderOnce(ctx context.Context, prov payment.Provider, queryRef string) (*payment.QueryOrderResponse, bool) {
+ if prov == nil || strings.TrimSpace(queryRef) == "" {
+ return nil, false
+ }
+ resp, err := prov.QueryOrder(ctx, queryRef)
+ if err != nil {
+ slog.Warn("query upstream retry failed", "queryRef", queryRef, "error", err)
+ return nil, false
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusPaid || !isValidProviderAmount(resp.Amount) {
+ return nil, false
+ }
+ return resp, true
+}
+
+func paymentOrderQueryReference(order *dbent.PaymentOrder, prov payment.Provider) string {
+ if order == nil {
+ return ""
+ }
+
+ providerKey := ""
+ if prov != nil {
+ providerKey = strings.TrimSpace(prov.ProviderKey())
+ }
+ if providerKey == "" {
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil {
+ providerKey = strings.TrimSpace(snapshot.ProviderKey)
+ }
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(psStringValue(order.ProviderKey))
+ }
+ if providerKey == "" {
+ providerKey = strings.TrimSpace(order.PaymentType)
+ }
+
+ switch payment.GetBasePaymentType(providerKey) {
+ case payment.TypeAlipay, payment.TypeEasyPay, payment.TypeWxpay:
+ return strings.TrimSpace(order.OutTradeNo)
+ default:
+ if tradeNo := strings.TrimSpace(order.PaymentTradeNo); tradeNo != "" {
+ return tradeNo
+ }
+ return strings.TrimSpace(order.OutTradeNo)
+ }
+}
+
+func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, currentTradeNo string) bool {
+ upstreamTradeNo = strings.TrimSpace(upstreamTradeNo)
+ if upstreamTradeNo == "" {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(currentTradeNo)) {
+ return false
+ }
+ if strings.EqualFold(upstreamTradeNo, strings.TrimSpace(queryRef)) {
+ return false
+ }
+ return true
+}
+
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
+ outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
+ if err != nil {
+ return nil, err
+ }
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
@@ -190,25 +280,42 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
-// VerifyOrderPublic verifies payment status without user authentication.
-// Used by the payment result page when the user's session has expired.
+// VerifyOrderPublic returns the currently persisted public order state without
+// triggering any upstream reconciliation. Signed resume-token recovery is the
+// only public recovery path allowed to query upstream state.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
+ outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo)
+ if err != nil {
+ return nil, err
+ }
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
- if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
- result := s.checkPaid(ctx, o)
- if result == checkPaidResultAlreadyPaid {
- o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
- if err != nil {
- return nil, fmt.Errorf("reload order: %w", err)
- }
+ return o, nil
+}
+
+func normalizeOrderLookupOutTradeNo(raw string) (string, error) {
+ outTradeNo := strings.TrimSpace(raw)
+ if outTradeNo == "" {
+ return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required")
+ }
+ if len(outTradeNo) > 64 {
+ return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
+ }
+ for _, ch := range outTradeNo {
+ switch {
+ case ch >= 'a' && ch <= 'z':
+ case ch >= 'A' && ch <= 'Z':
+ case ch >= '0' && ch <= '9':
+ case ch == '_' || ch == '-':
+ default:
+ return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid")
}
}
- return o, nil
+ return outTradeNo, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
@@ -236,22 +343,79 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error)
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
- instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
- if err == nil {
- cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
- if err == nil {
- providerKey := s.registry.GetProviderKey(o.PaymentType)
- if providerKey == "" {
- providerKey = o.PaymentType
- }
- p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
- if err == nil {
- return p, nil
- }
- }
- }
+ inst, err := s.getOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst != nil {
+ return s.createProviderFromInstance(ctx, inst)
+ }
+ if !paymentOrderAllowsRegistryFallback(o) {
+ return nil, fmt.Errorf("order %d provider instance is unresolved", o.ID)
+ }
+ providerKey := paymentOrderFallbackProviderKey(s.registry, o)
+ if providerKey == "" {
+ return nil, fmt.Errorf("order %d provider fallback key is missing", o.ID)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("order %d provider fallback is ambiguous for %s", o.ID, providerKey)
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}
+
+func paymentOrderAllowsRegistryFallback(order *dbent.PaymentOrder) bool {
+ if order == nil {
+ return false
+ }
+ if psOrderProviderSnapshot(order) != nil {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != "" {
+ return false
+ }
+ if strings.TrimSpace(psStringValue(order.ProviderKey)) != "" {
+ return false
+ }
+ return true
+}
+
+func paymentOrderFallbackProviderKey(registry *payment.Registry, order *dbent.PaymentOrder) string {
+ if order == nil {
+ return ""
+ }
+ if registry != nil {
+ if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(order.PaymentType))); key != "" {
+ return key
+ }
+ }
+ return strings.TrimSpace(payment.GetBasePaymentType(strings.TrimSpace(order.PaymentType)))
+}
+
+func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) {
+ if inst == nil {
+ return nil, fmt.Errorf("payment provider instance is missing")
+ }
+
+ cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID))
+ if err != nil {
+ return nil, fmt.Errorf("load provider instance config: %w", err)
+ }
+ if inst.PaymentMode != "" {
+ cfg["paymentMode"] = inst.PaymentMode
+ }
+
+ instID := strconv.FormatInt(int64(inst.ID), 10)
+ prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg)
+ if err != nil {
+ return nil, fmt.Errorf("create provider from instance: %w", err)
+ }
+ return prov, nil
+}
+
+func psStringValue(value *string) string {
+ if value == nil {
+ return ""
+ }
+ return *value
+}
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
new file mode 100644
index 00000000..8dfd2e7e
--- /dev/null
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -0,0 +1,575 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type paymentOrderLifecycleQueryProvider struct {
+ lastQueryTradeNo string
+ queryCalls int
+ responses []*payment.QueryOrderResponse
+ resp *payment.QueryOrderResponse
+}
+
+type paymentOrderLifecycleRedeemRepo struct {
+ codesByCode map[string]*RedeemCode
+ useCalls []struct {
+ id int64
+ userID int64
+ }
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Name() string {
+ return "payment-order-lifecycle-query-provider"
+}
+
+func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ p.lastQueryTradeNo = tradeNo
+ p.queryCalls++
+ if len(p.responses) > 0 {
+ resp := p.responses[0]
+ if len(p.responses) > 1 {
+ p.responses = p.responses[1:]
+ }
+ return resp, nil
+ }
+ return p.resp, nil
+}
+
+func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) {
+ for _, code := range r.codesByCode {
+ if code.ID != id {
+ continue
+ }
+ cloned := *code
+ return &cloned, nil
+ }
+ return nil, ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+ redeemCode, ok := r.codesByCode[code]
+ if !ok {
+ return nil, ErrRedeemCodeNotFound
+ }
+ cloned := *redeemCode
+ return &cloned, nil
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error {
+ for code, redeemCode := range r.codesByCode {
+ if redeemCode.ID != id {
+ continue
+ }
+ now := time.Now().UTC()
+ redeemCode.Status = StatusUsed
+ redeemCode.UsedBy = &userID
+ redeemCode.UsedAt = &now
+ r.codesByCode[code] = redeemCode
+ r.useCalls = append(r.useCalls, struct {
+ id int64
+ userID int64
+ }{id: id, userID: userID})
+ return nil
+ }
+ return ErrRedeemCodeNotFound
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected call")
+}
+
+func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected call")
+}
+
+func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_trade_no_missing").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-123",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, OrderStatusCompleted, got.Status)
+ require.Equal(t, "upstream-trade-123", got.PaymentTradeNo)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusCompleted, reloaded.Status)
+ require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo)
+
+ require.Equal(t, 88.0, userRepo.getByIDUser.Balance)
+ require.Len(t, redeemRepo.useCalls, 1)
+ require.Equal(t, int64(1), redeemRepo.useCalls[0].id)
+ require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
+}
+
+func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-retry@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-retry-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-UPSTREAM-RETRY").
+ SetOutTradeNo("sub2_checkpaid_retry_zero_amount").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ responses: []*payment.QueryOrderResponse{
+ {
+ TradeNo: "upstream-trade-zero",
+ Status: payment.ProviderStatusPaid,
+ Amount: 0,
+ },
+ {
+ TradeNo: "upstream-trade-retry",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2, provider.queryCalls)
+ require.Equal(t, OrderStatusCompleted, got.Status)
+ require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo)
+}
+
+func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-zero-amount@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-zero-amount-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-ZERO-AMOUNT").
+ SetOutTradeNo("sub2_checkpaid_zero_amount").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-zero",
+ Status: payment.ProviderStatusPaid,
+ Amount: 0,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, OrderStatusPending, got.Status)
+ require.Empty(t, got.PaymentTradeNo)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusPending, reloaded.Status)
+ require.Empty(t, reloaded.PaymentTradeNo)
+
+ require.Equal(t, 0.0, userRepo.getByIDUser.Balance)
+ require.Empty(t, redeemRepo.useCalls)
+}
+
+func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-existing-trade@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-existing-trade-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO").
+ SetOutTradeNo("sub2_checkpaid_use_out_trade_no").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("upstream-trade-existing").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "upstream-trade-existing",
+ Status: payment.ProviderStatusPaid,
+ Amount: 88,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo)
+}
+
+func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ }))
+
+ instanceID := "12"
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderInstanceID: &instanceID,
+ }))
+
+ require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipay,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "12",
+ },
+ }))
+}
+
+func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) {
+ t.Parallel()
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ OutTradeNo: "sub2_out_trade_no",
+ PaymentTradeNo: "wx-transaction-id",
+ }
+
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{}))
+ require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{
+ key: payment.TypeWxpay,
+ }))
+}
+
+func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot.go b/backend/internal/service/payment_order_provider_snapshot.go
new file mode 100644
index 00000000..bb60f9e2
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot.go
@@ -0,0 +1,205 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+type paymentOrderProviderSnapshot struct {
+ SchemaVersion int
+ ProviderInstanceID string
+ ProviderKey string
+ PaymentMode string
+ MerchantAppID string
+ MerchantID string
+ Currency string
+}
+
+func psOrderProviderSnapshot(order *dbent.PaymentOrder) *paymentOrderProviderSnapshot {
+ if order == nil || len(order.ProviderSnapshot) == 0 {
+ return nil
+ }
+
+ snapshot := &paymentOrderProviderSnapshot{
+ SchemaVersion: psSnapshotIntValue(order.ProviderSnapshot["schema_version"]),
+ ProviderInstanceID: psSnapshotStringValue(order.ProviderSnapshot["provider_instance_id"]),
+ ProviderKey: psSnapshotStringValue(order.ProviderSnapshot["provider_key"]),
+ PaymentMode: psSnapshotStringValue(order.ProviderSnapshot["payment_mode"]),
+ MerchantAppID: psSnapshotStringValue(order.ProviderSnapshot["merchant_app_id"]),
+ MerchantID: psSnapshotStringValue(order.ProviderSnapshot["merchant_id"]),
+ Currency: psSnapshotStringValue(order.ProviderSnapshot["currency"]),
+ }
+ if snapshot.SchemaVersion == 0 &&
+ snapshot.ProviderInstanceID == "" &&
+ snapshot.ProviderKey == "" &&
+ snapshot.PaymentMode == "" &&
+ snapshot.MerchantAppID == "" &&
+ snapshot.MerchantID == "" &&
+ snapshot.Currency == "" {
+ return nil
+ }
+ return snapshot
+}
+
+func psSnapshotStringValue(value any) string {
+ switch typed := value.(type) {
+ case string:
+ return strings.TrimSpace(typed)
+ default:
+ return ""
+ }
+}
+
+func psSnapshotIntValue(value any) int {
+ switch typed := value.(type) {
+ case int:
+ return typed
+ case int32:
+ return int(typed)
+ case int64:
+ return int(typed)
+ case float32:
+ return int(typed)
+ case float64:
+ return int(typed)
+ case string:
+ n, err := strconv.Atoi(strings.TrimSpace(typed))
+ if err == nil {
+ return n
+ }
+ }
+ return 0
+}
+
+func (s *PaymentService) resolveSnapshotOrderProviderInstance(ctx context.Context, order *dbent.PaymentOrder, snapshot *paymentOrderProviderSnapshot) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || order == nil || snapshot == nil {
+ return nil, nil
+ }
+
+ snapshotInstanceID := strings.TrimSpace(snapshot.ProviderInstanceID)
+ columnInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ if snapshotInstanceID == "" {
+ snapshotInstanceID = columnInstanceID
+ }
+ if snapshotInstanceID == "" {
+ return nil, fmt.Errorf("order %d provider snapshot is missing provider_instance_id", order.ID)
+ }
+ if columnInstanceID != "" && snapshot.ProviderInstanceID != "" && !strings.EqualFold(columnInstanceID, snapshot.ProviderInstanceID) {
+ return nil, fmt.Errorf("order %d provider snapshot instance mismatch: snapshot=%s order=%s", order.ID, snapshot.ProviderInstanceID, columnInstanceID)
+ }
+
+ instID, err := strconv.ParseInt(snapshotInstanceID, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d provider snapshot instance id is invalid: %s", order.ID, snapshotInstanceID)
+ }
+
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d provider snapshot instance %s is missing", order.ID, snapshotInstanceID)
+ }
+ return nil, err
+ }
+
+ if snapshot.ProviderKey != "" && !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), snapshot.ProviderKey) {
+ return nil, fmt.Errorf("order %d provider snapshot key mismatch: snapshot=%s instance=%s", order.ID, snapshot.ProviderKey, inst.ProviderKey)
+ }
+
+ return inst, nil
+}
+
+func expectedNotificationProviderKeyForOrder(registry *payment.Registry, order *dbent.PaymentOrder, instanceProviderKey string) string {
+ if order == nil {
+ return strings.TrimSpace(instanceProviderKey)
+ }
+
+ orderProviderKey := psStringValue(order.ProviderKey)
+ if snapshot := psOrderProviderSnapshot(order); snapshot != nil && snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+
+ return expectedNotificationProviderKey(registry, order.PaymentType, orderProviderKey, instanceProviderKey)
+}
+
+func validateProviderSnapshotMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
+ if order == nil || len(metadata) == 0 {
+ return nil
+ }
+
+ snapshot := psOrderProviderSnapshot(order)
+ if snapshot == nil {
+ return nil
+ }
+
+ switch strings.TrimSpace(providerKey) {
+ case payment.TypeWxpay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["appid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing appid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay appid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["mchid"])
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing mchid")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay mchid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if expected := strings.TrimSpace(snapshot.Currency); expected != "" {
+ actual := strings.ToUpper(strings.TrimSpace(metadata["currency"]))
+ if actual == "" {
+ return fmt.Errorf("wxpay notification missing currency")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("wxpay currency mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ if actual := strings.TrimSpace(metadata["trade_state"]); actual != "" && !strings.EqualFold(actual, "SUCCESS") {
+ return fmt.Errorf("wxpay trade_state mismatch: expected SUCCESS, got %s", actual)
+ }
+ case payment.TypeAlipay:
+ if expected := strings.TrimSpace(snapshot.MerchantAppID); expected != "" {
+ actual := strings.TrimSpace(metadata["app_id"])
+ if actual == "" {
+ return fmt.Errorf("alipay app_id missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("alipay app_id mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ case payment.TypeEasyPay:
+ if expected := strings.TrimSpace(snapshot.MerchantID); expected != "" {
+ actual := strings.TrimSpace(metadata["pid"])
+ if actual == "" {
+ return fmt.Errorf("easypay pid missing")
+ }
+ if !strings.EqualFold(expected, actual) {
+ return fmt.Errorf("easypay pid mismatch: expected %s, got %s", expected, actual)
+ }
+ }
+ }
+
+ return nil
+}
+
+func providerMerchantIdentityMetadata(prov payment.Provider) map[string]string {
+ if prov == nil {
+ return nil
+ }
+ reporter, ok := prov.(payment.MerchantIdentityProvider)
+ if !ok {
+ return nil
+ }
+ return reporter.MerchantIdentityMetadata()
+}
diff --git a/backend/internal/service/payment_order_provider_snapshot_test.go b/backend/internal/service/payment_order_provider_snapshot_test.go
new file mode 100644
index 00000000..efa013b5
--- /dev/null
+++ b/backend/internal/service/payment_order_provider_snapshot_test.go
@@ -0,0 +1,172 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "strconv"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildPaymentOrderProviderSnapshot_ExcludesSensitiveConfig(t *testing.T) {
+ t.Parallel()
+
+ sel := &payment.InstanceSelection{
+ InstanceID: "12",
+ ProviderKey: payment.TypeWxpay,
+ SupportedTypes: "wxpay,wxpay_direct",
+ PaymentMode: "popup",
+ Config: map[string]string{
+ "privateKey": "secret",
+ "apiV3Key": "secret-v3",
+ "appId": "wx-app-id",
+ },
+ }
+
+ snapshot := buildPaymentOrderProviderSnapshot(sel, CreateOrderRequest{})
+ require.Equal(t, map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "12",
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "popup",
+ "merchant_app_id": "wx-app-id",
+ "currency": "CNY",
+ }, snapshot)
+ require.NotContains(t, snapshot, "config")
+ require.NotContains(t, snapshot, "privateKey")
+ require.NotContains(t, snapshot, "apiV3Key")
+ require.NotContains(t, snapshot, "supported_types")
+ require.NotContains(t, snapshot, "instance_name")
+ require.NotContains(t, snapshot, "merchant_id")
+}
+
+func TestCreateOrderInTx_WritesProviderSnapshot(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("snapshot@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instance, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Primary Alipay").
+ SetConfig(`{"secretKey":"do-not-copy"}`).
+ SetSupportedTypes("alipay,alipay_direct").
+ SetPaymentMode("redirect").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{entClient: client}
+ order, err := svc.createOrderInTx(
+ ctx,
+ CreateOrderRequest{
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ OrderType: payment.OrderTypeBalance,
+ ClientIP: "127.0.0.1",
+ SrcHost: "app.example.com",
+ },
+ &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ },
+ nil,
+ &PaymentConfig{
+ MaxPendingOrders: 3,
+ OrderTimeoutMin: 30,
+ },
+ 88,
+ 88,
+ 0,
+ 88,
+ &payment.InstanceSelection{
+ InstanceID: strconv.FormatInt(instance.ID, 10),
+ ProviderKey: payment.TypeAlipay,
+ SupportedTypes: "alipay,alipay_direct",
+ PaymentMode: "redirect",
+ Config: map[string]string{
+ "secretKey": "do-not-copy",
+ },
+ },
+ )
+ require.NoError(t, err)
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), valueOrEmpty(order.ProviderInstanceID))
+ require.Equal(t, payment.TypeAlipay, valueOrEmpty(order.ProviderKey))
+ require.Equal(t, float64(2), order.ProviderSnapshot["schema_version"])
+ require.Equal(t, strconv.FormatInt(instance.ID, 10), order.ProviderSnapshot["provider_instance_id"])
+ require.Equal(t, payment.TypeAlipay, order.ProviderSnapshot["provider_key"])
+ require.Equal(t, "redirect", order.ProviderSnapshot["payment_mode"])
+ require.NotContains(t, order.ProviderSnapshot, "config")
+ require.NotContains(t, order.ProviderSnapshot, "secretKey")
+ require.NotContains(t, order.ProviderSnapshot, "supported_types")
+ require.NotContains(t, order.ProviderSnapshot, "instance_name")
+}
+
+func TestBuildPaymentOrderProviderSnapshot_UsesWxpayJSAPIAppIDForOpenIDOrders(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "88",
+ ProviderKey: payment.TypeWxpay,
+ Config: map[string]string{
+ "appId": "wx-open-app",
+ "mpAppId": "wx-mp-app",
+ "mchId": "mch-88",
+ },
+ PaymentMode: "jsapi",
+ }, CreateOrderRequest{OpenID: "openid-123"})
+
+ require.Equal(t, "wx-mp-app", snapshot["merchant_app_id"])
+ require.Equal(t, "mch-88", snapshot["merchant_id"])
+ require.Equal(t, "CNY", snapshot["currency"])
+}
+
+func TestBuildPaymentOrderProviderSnapshot_IncludesAlipayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "21",
+ ProviderKey: payment.TypeAlipay,
+ Config: map[string]string{
+ "appId": "alipay-app-21",
+ "privateKey": "secret",
+ },
+ PaymentMode: "redirect",
+ }, CreateOrderRequest{})
+
+ require.Equal(t, "alipay-app-21", snapshot["merchant_app_id"])
+ require.NotContains(t, snapshot, "privateKey")
+}
+
+func TestBuildPaymentOrderProviderSnapshot_IncludesEasyPayMerchantIdentity(t *testing.T) {
+ t.Parallel()
+
+ snapshot := buildPaymentOrderProviderSnapshot(&payment.InstanceSelection{
+ InstanceID: "66",
+ ProviderKey: payment.TypeEasyPay,
+ Config: map[string]string{
+ "pid": "easypay-merchant-66",
+ "pkey": "secret",
+ },
+ PaymentMode: "popup",
+ }, CreateOrderRequest{PaymentType: payment.TypeAlipay})
+
+ require.Equal(t, "easypay-merchant-66", snapshot["merchant_id"])
+ require.NotContains(t, snapshot, "pkey")
+}
+
+func valueOrEmpty(v *string) string {
+ if v == nil {
+ return ""
+ }
+ return *v
+}
diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go
new file mode 100644
index 00000000..2d7412e0
--- /dev/null
+++ b/backend/internal/service/payment_order_result_test.go
@@ -0,0 +1,276 @@
+package service
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func TestBuildCreateOrderResponseDefaultsToOrderCreated(t *testing.T) {
+ t.Parallel()
+
+ expiresAt := time.Date(2026, 4, 16, 12, 0, 0, 0, time.UTC)
+ resp := buildCreateOrderResponse(
+ &dbent.PaymentOrder{
+ ID: 42,
+ Amount: 12.34,
+ FeeRate: 0.03,
+ ExpiresAt: expiresAt,
+ OutTradeNo: "sub2_42",
+ },
+ CreateOrderRequest{PaymentType: payment.TypeWxpay},
+ 12.71,
+ &payment.InstanceSelection{PaymentMode: "qrcode"},
+ &payment.CreatePaymentResponse{
+ TradeNo: "sub2_42",
+ QRCode: "weixin://wxpay/bizpayurl?pr=test",
+ },
+ payment.CreatePaymentResultOrderCreated,
+ )
+
+ if resp.ResultType != payment.CreatePaymentResultOrderCreated {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOrderCreated)
+ }
+ if resp.OutTradeNo != "sub2_42" {
+ t.Fatalf("out_trade_no = %q, want %q", resp.OutTradeNo, "sub2_42")
+ }
+ if resp.QRCode != "weixin://wxpay/bizpayurl?pr=test" {
+ t.Fatalf("qr_code = %q, want %q", resp.QRCode, "weixin://wxpay/bizpayurl?pr=test")
+ }
+ if resp.JSAPI != nil || resp.JSAPIPayload != nil {
+ t.Fatal("order_created response should not include jsapi payload")
+ }
+ if !resp.ExpiresAt.Equal(expiresAt) {
+ t.Fatalf("expires_at = %v, want %v", resp.ExpiresAt, expiresAt)
+ }
+}
+
+func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) {
+ t.Parallel()
+
+ jsapiPayload := &payment.WechatJSAPIPayload{
+ AppID: "wx123",
+ TimeStamp: "1712345678",
+ NonceStr: "nonce-123",
+ Package: "prepay_id=wx123",
+ SignType: "RSA",
+ PaySign: "signed-payload",
+ }
+ resp := buildCreateOrderResponse(
+ &dbent.PaymentOrder{
+ ID: 88,
+ Amount: 66.88,
+ FeeRate: 0.01,
+ ExpiresAt: time.Date(2026, 4, 16, 13, 0, 0, 0, time.UTC),
+ OutTradeNo: "sub2_88",
+ },
+ CreateOrderRequest{PaymentType: payment.TypeWxpay},
+ 67.55,
+ &payment.InstanceSelection{PaymentMode: "popup"},
+ &payment.CreatePaymentResponse{
+ TradeNo: "sub2_88",
+ ResultType: payment.CreatePaymentResultJSAPIReady,
+ JSAPI: jsapiPayload,
+ },
+ payment.CreatePaymentResultJSAPIReady,
+ )
+
+ if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
+ }
+ if resp.JSAPI == nil || resp.JSAPIPayload == nil {
+ t.Fatal("expected jsapi payload aliases to be populated")
+ }
+ if resp.JSAPI != jsapiPayload || resp.JSAPIPayload != jsapiPayload {
+ t.Fatal("expected jsapi aliases to preserve the original pointer")
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
+
+ svc := newWeChatPaymentOAuthTestService(map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if resp == nil {
+ t.Fatal("expected oauth_required response, got nil")
+ }
+ if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
+ }
+ if resp.OAuth == nil {
+ t.Fatal("expected oauth payload, got nil")
+ }
+ if resp.OAuth.AppID != "wx123456" {
+ t.Fatalf("appid = %q, want %q", resp.OAuth.AppID, "wx123456")
+ }
+ if resp.OAuth.Scope != "snsapi_base" {
+ t.Fatalf("scope = %q, want %q", resp.OAuth.Scope, "snsapi_base")
+ }
+ if resp.OAuth.RedirectURL != "/auth/wechat/payment/callback" {
+ t.Fatalf("redirect_url = %q, want %q", resp.OAuth.RedirectURL, "/auth/wechat/payment/callback")
+ }
+ if resp.OAuth.AuthorizeURL != "/api/v1/auth/oauth/wechat/payment/start?amount=12.5&order_type=balance&payment_type=wxpay&redirect=%2Fpurchase%3Ffrom%3Dwechat&scope=snsapi_base" {
+ t.Fatalf("authorize_url = %q", resp.OAuth.AuthorizeURL)
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) {
+ t.Parallel()
+
+ svc := newWeChatPaymentOAuthTestService(nil)
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+
+ appErr := infraerrors.FromError(err)
+ if appErr.Reason != "WECHAT_PAYMENT_MP_NOT_CONFIGURED" {
+ t.Fatalf("reason = %q, want %q", appErr.Reason, "WECHAT_PAYMENT_MP_NOT_CONFIGURED")
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ }},
+ // Intentionally missing payment resume signing key.
+ encryptionKey: nil,
+ },
+ }
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+
+ appErr := infraerrors.FromError(err)
+ if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" {
+ t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED")
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) {
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ }},
+ // Legacy stable signing key remains available for no-config upgrade compatibility.
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ },
+ }
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ SrcURL: "https://merchant.example/payment?from=wechat",
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03)
+ if err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ if resp == nil {
+ t.Fatal("expected oauth-required response, got nil")
+ }
+ if resp.ResultType != payment.CreatePaymentResultOAuthRequired {
+ t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired)
+ }
+ if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" {
+ t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth)
+ }
+}
+
+func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
+ svc := newWeChatPaymentOAuthTestService(map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx123456",
+ SettingKeyWeChatConnectAppSecret: "wechat-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ })
+
+ resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{
+ Amount: 12.5,
+ PaymentType: payment.TypeWxpay,
+ IsWeChatBrowser: true,
+ OrderType: payment.OrderTypeBalance,
+ }, 12.5, 12.88, 0.03, &payment.InstanceSelection{
+ ProviderKey: payment.TypeEasyPay,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if resp != nil {
+ t.Fatalf("expected nil response, got %+v", resp)
+ }
+}
+
+func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
+ return &PaymentService{
+ configService: &PaymentConfigService{
+ settingRepo: &paymentConfigSettingRepoStub{values: values},
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ },
+ }
+}
diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go
index c5bda763..7521878c 100644
--- a/backend/internal/service/payment_refund.go
+++ b/backend/internal/service/payment_refund.go
@@ -12,6 +12,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
@@ -19,18 +20,133 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
-// Returns nil, nil for legacy orders without provider_instance_id.
+// For legacy orders without provider_instance_id, it resolves only when the
+// historical instance is uniquely identifiable from the stored order fields.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
- if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
+ if s == nil || s.entClient == nil || o == nil {
return nil, nil
}
- instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
+
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return s.resolveUniqueLegacyOrderProviderInstance(ctx, o)
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
+// getRefundOrderProviderInstance resolves the provider instance for refund paths.
+// Refunds must be pinned to an explicit historical binding, so legacy
+// "best-effort" provider guessing is intentionally not allowed here.
+func (s *PaymentService) getRefundOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil || o == nil {
+ return nil, nil
+ }
+
+ if snapshot := psOrderProviderSnapshot(o); snapshot != nil {
+ return s.resolveSnapshotOrderProviderInstance(ctx, o, snapshot)
+ }
+
+ instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID))
+ if instIDStr == "" {
+ return nil, nil
+ }
+
+ instID, err := strconv.ParseInt(instIDStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("order %d refund provider instance id is invalid: %s", o.ID, instIDStr)
+ }
+ inst, err := s.entClient.PaymentProviderInstance.Get(ctx, instID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, fmt.Errorf("order %d refund provider instance %s is missing", o.ID, instIDStr)
+ }
+ return nil, err
+ }
+ return inst, nil
+}
+
+func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
+ paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType))
+ providerKey := strings.TrimSpace(psStringValue(o.ProviderKey))
+ if providerKey != "" {
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.ProviderKeyEQ(providerKey)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
+ if len(matched) == 1 {
+ return matched[0], nil
+ }
+ return nil, nil
+ }
+
+ if paymentType == "" {
+ return nil, nil
+ }
+
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ matched := psFilterLegacyOrderProviderInstances(paymentType, instances)
+ if len(matched) == 1 {
+ return matched[0], nil
+ }
+ return nil, nil
+}
+
+func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance {
+ if len(instances) == 0 {
+ return nil
+ }
+ if strings.TrimSpace(orderPaymentType) == "" {
+ return instances
+ }
+ var matched []*dbent.PaymentProviderInstance
+ for _, inst := range instances {
+ if psLegacyOrderMatchesInstance(orderPaymentType, inst) {
+ matched = append(matched, inst)
+ }
+ }
+ return matched
+}
+
+func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool {
+ if inst == nil {
+ return false
+ }
+
+ baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType))
+ instanceProviderKey := strings.TrimSpace(inst.ProviderKey)
+ if baseType == "" {
+ return false
+ }
+
+ if baseType == payment.TypeStripe {
+ return instanceProviderKey == payment.TypeStripe
+ }
+ if instanceProviderKey == payment.TypeStripe {
+ return false
+ }
+ if instanceProviderKey == baseType {
+ return true
+ }
+ return payment.InstanceSupportsType(inst.SupportedTypes, baseType)
+}
+
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
@@ -72,7 +188,7 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
- inst, err := s.getOrderProviderInstance(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
@@ -92,7 +208,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
- inst, instErr := s.getOrderProviderInstance(ctx, o)
+ inst, instErr := s.getRefundOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
@@ -217,6 +333,12 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
if err != nil {
return fmt.Errorf("get refund provider: %w", err)
}
+ if err := validateProviderSnapshotMetadata(p.Order, prov.ProviderKey(), providerMerchantIdentityMetadata(prov)); err != nil {
+ s.writeAuditLog(ctx, p.Order.ID, "REFUND_PROVIDER_METADATA_MISMATCH", "admin", map[string]any{
+ "detail": err.Error(),
+ })
+ return err
+ }
_, err = prov.Refund(ctx, payment.RefundRequest{
TradeNo: p.Order.PaymentTradeNo,
OrderID: p.Order.OutTradeNo,
@@ -229,7 +351,14 @@ func (s *PaymentService) gwRefund(ctx context.Context, p *RefundPlan) error {
// getRefundProvider creates a provider using the order's original instance config.
// Delegates to getOrderProvider which handles instance lookup and fallback.
func (s *PaymentService) getRefundProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
- return s.getOrderProvider(ctx, o)
+ inst, err := s.getRefundOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, err
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("refund provider instance is unavailable for order %d", o.ID)
+ }
+ return s.createProviderFromInstance(ctx, inst)
}
func (s *PaymentService) handleGwFail(ctx context.Context, p *RefundPlan, gErr error) (*RefundResult, error) {
diff --git a/backend/internal/service/payment_refund_test.go b/backend/internal/service/payment_refund_test.go
new file mode 100644
index 00000000..ca5b62cb
--- /dev/null
+++ b/backend/internal/service/payment_refund_test.go
@@ -0,0 +1,186 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateRefundRequestRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ _, err = svc.validateRefundRequest(ctx, order.ID, user.ID)
+ require.Error(t, err)
+ require.Equal(t, "USER_REFUND_DISABLED", infraerrors.Reason(err))
+}
+
+func TestPrepareRefundRejectsLegacyGuessedProviderInstance(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-legacy-admin@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-legacy-admin-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-admin-instance").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetAllowUserRefund(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(188).
+ SetPayAmount(188).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-LEGACY-ADMIN-ORDER").
+ SetOutTradeNo("sub2_refund_legacy_admin_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-legacy-admin-refund").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ }
+
+ plan, result, err := svc.PrepareRefund(ctx, order.ID, 0, "", false, false)
+ require.Nil(t, plan)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Equal(t, "REFUND_DISABLED", infraerrors.Reason(err))
+}
+
+func TestGwRefundRejectsAlipayMerchantIdentitySnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("refund-snapshot-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("refund-snapshot-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-refund-mismatch-instance").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "runtime-alipay-app",
+ "privateKey": "runtime-private-key",
+ })).
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetRefundEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ instID := strconv.FormatInt(inst.ID, 10)
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("REFUND-SNAPSHOT-MISMATCH-ORDER").
+ SetOutTradeNo("sub2_refund_snapshot_mismatch_order").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-refund-snapshot-mismatch").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusCompleted).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetPaidAt(time.Now()).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instID).
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": instID,
+ "provider_key": payment.TypeAlipay,
+ "merchant_app_id": "expected-alipay-app",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ err = svc.gwRefund(ctx, &RefundPlan{
+ OrderID: order.ID,
+ Order: order,
+ RefundAmount: order.Amount,
+ GatewayAmount: order.Amount,
+ Reason: "snapshot mismatch",
+ })
+ require.ErrorContains(t, err, "alipay app_id mismatch")
+}
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
new file mode 100644
index 00000000..1ff061e8
--- /dev/null
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -0,0 +1,67 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) {
+ claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token))
+ if err != nil {
+ return nil, err
+ }
+
+ order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
+ }
+ return nil, fmt.Errorf("get order by resume token: %w", err)
+ }
+ if claims.UserID > 0 && order.UserID != claims.UserID {
+ return nil, invalidResumeTokenMatchError()
+ }
+ snapshot := psOrderProviderSnapshot(order)
+ orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID))
+ orderProviderKey := strings.TrimSpace(psStringValue(order.ProviderKey))
+ if snapshot != nil {
+ if snapshot.ProviderInstanceID != "" {
+ orderProviderInstanceID = snapshot.ProviderInstanceID
+ }
+ if snapshot.ProviderKey != "" {
+ orderProviderKey = snapshot.ProviderKey
+ }
+ }
+ if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID {
+ return nil, invalidResumeTokenMatchError()
+ }
+ if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) {
+ return nil, invalidResumeTokenMatchError()
+ }
+ if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) {
+ return nil, invalidResumeTokenMatchError()
+ }
+ if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
+ result := s.checkPaid(ctx, order)
+ if result == checkPaidResultAlreadyPaid {
+ order, err = s.entClient.PaymentOrder.Get(ctx, order.ID)
+ if err != nil {
+ return nil, fmt.Errorf("reload order by resume token: %w", err)
+ }
+ }
+ }
+
+ return order, nil
+}
+
+func invalidResumeTokenMatchError() error {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order")
+}
+
+func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token))
+}
diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go
new file mode 100644
index 00000000..a7b5b737
--- /dev/null
+++ b/backend/internal/service/payment_resume_lookup_test.go
@@ -0,0 +1,315 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+type paymentResumeLookupProvider struct {
+ queryCount int
+}
+
+func (p *paymentResumeLookupProvider) Name() string { return "resume-lookup-provider" }
+
+func (p *paymentResumeLookupProvider) ProviderKey() string { return payment.TypeAlipay }
+
+func (p *paymentResumeLookupProvider) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+func (p *paymentResumeLookupProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentResumeLookupProvider) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ p.queryCount++
+ return &payment.QueryOrderResponse{Status: payment.ProviderStatusPending}, nil
+}
+
+func (p *paymentResumeLookupProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+
+func (p *paymentResumeLookupProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ instanceID := "12"
+ providerKey := payment.TypeEasyPay
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-ORDER").
+ SetOutTradeNo("sub2_resume_lookup").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-1").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(instanceID).
+ SetProviderKey(providerKey).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: instanceID,
+ ProviderKey: providerKey,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
+func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-mismatch@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-mismatch-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-MISMATCH").
+ SetOutTradeNo("sub2_resume_lookup_mismatch").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-2").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("12").
+ SetProviderKey(payment.TypeEasyPay).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "99",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ _, err = svc.GetPublicOrderByResumeToken(ctx, token)
+ require.Error(t, err)
+ require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err))
+}
+
+func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-snapshot-authority@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-snapshot-authority-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-SNAPSHOT-AUTHORITY").
+ SetOutTradeNo("sub2_resume_snapshot_authority").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-snapshot-authority").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID("legacy-column-instance").
+ SetProviderKey(payment.TypeAlipay).
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 2,
+ "provider_instance_id": "snapshot-instance",
+ "provider_key": payment.TypeEasyPay,
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ ProviderInstanceID: "snapshot-instance",
+ ProviderKey: payment.TypeEasyPay,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ resumeService: resumeSvc,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+}
+
+func TestGetPublicOrderByResumeTokenChecksUpstreamForPendingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("resume-refresh@example.com").
+ SetPasswordHash("hash").
+ SetUsername("resume-refresh-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("RESUME-PENDING").
+ SetOutTradeNo("sub2_resume_lookup_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-pending").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := resumeSvc.CreateToken(ResumeTokenClaims{
+ OrderID: order.ID,
+ UserID: user.ID,
+ PaymentType: payment.TypeAlipay,
+ CanonicalReturnURL: "https://app.example.com/payment/result",
+ })
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentResumeLookupProvider{}
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ resumeService: resumeSvc,
+ providersLoaded: true,
+ }
+
+ got, err := svc.GetPublicOrderByResumeToken(ctx, token)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+ require.Equal(t, 1, provider.queryCount)
+}
+
+func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("public-verify@example.com").
+ SetPasswordHash("hash").
+ SetUsername("public-verify-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("PUBLIC-VERIFY").
+ SetOutTradeNo("sub2_public_verify_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("trade-public-verify").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentResumeLookupProvider{}
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderPublic(ctx, order.OutTradeNo)
+ require.NoError(t, err)
+ require.Equal(t, order.ID, got.ID)
+ require.Equal(t, 0, provider.queryCount)
+}
+
+func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
+ svc := &PaymentService{
+ entClient: newPaymentConfigServiceTestClient(t),
+ }
+
+ _, err := svc.VerifyOrderPublic(context.Background(), " ")
+ require.Error(t, err)
+ require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err))
+}
diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go
new file mode 100644
index 00000000..9ae62fde
--- /dev/null
+++ b/backend/internal/service/payment_resume_service.go
@@ -0,0 +1,476 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const paymentResultReturnPath = "/payment/result"
+
+const (
+ PaymentSourceHostedRedirect = "hosted_redirect"
+ PaymentSourceWechatInAppResume = "wechat_in_app_resume"
+
+ SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
+ SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
+ SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
+ SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
+
+ VisibleMethodSourceOfficialAlipay = "official_alipay"
+ VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
+ VisibleMethodSourceOfficialWechat = "official_wxpay"
+ VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
+
+ wechatPaymentResumeTokenType = "wechat_payment_resume"
+
+ paymentResumeNotConfiguredCode = "PAYMENT_RESUME_NOT_CONFIGURED"
+ paymentResumeNotConfiguredMessage = "payment resume tokens require a configured signing key"
+
+ paymentResumeTokenTTL = 24 * time.Hour
+ wechatPaymentResumeTokenTTL = 15 * time.Minute
+)
+
+type ResumeTokenClaims struct {
+ OrderID int64 `json:"oid"`
+ UserID int64 `json:"uid,omitempty"`
+ ProviderInstanceID string `json:"pi,omitempty"`
+ ProviderKey string `json:"pk,omitempty"`
+ PaymentType string `json:"pt,omitempty"`
+ CanonicalReturnURL string `json:"ru,omitempty"`
+ IssuedAt int64 `json:"iat"`
+ ExpiresAt int64 `json:"exp,omitempty"`
+}
+
+type WeChatPaymentResumeClaims struct {
+ TokenType string `json:"tk,omitempty"`
+ OpenID string `json:"openid"`
+ PaymentType string `json:"pt,omitempty"`
+ Amount string `json:"amt,omitempty"`
+ OrderType string `json:"ot,omitempty"`
+ PlanID int64 `json:"pid,omitempty"`
+ RedirectTo string `json:"rd,omitempty"`
+ Scope string `json:"scp,omitempty"`
+ IssuedAt int64 `json:"iat"`
+ ExpiresAt int64 `json:"exp,omitempty"`
+}
+
+type PaymentResumeService struct {
+ signingKey []byte
+ verifyKeys [][]byte
+}
+
+type visibleMethodLoadBalancer struct {
+ inner payment.LoadBalancer
+ configService *PaymentConfigService
+}
+
+func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService {
+ svc := &PaymentResumeService{}
+ if len(signingKey) > 0 {
+ svc.signingKey = append([]byte(nil), signingKey...)
+ svc.verifyKeys = append(svc.verifyKeys, svc.signingKey)
+ }
+ for _, fallback := range verifyFallbacks {
+ if len(fallback) == 0 {
+ continue
+ }
+ cloned := append([]byte(nil), fallback...)
+ duplicate := false
+ for _, existing := range svc.verifyKeys {
+ if bytes.Equal(existing, cloned) {
+ duplicate = true
+ break
+ }
+ }
+ if !duplicate {
+ svc.verifyKeys = append(svc.verifyKeys, cloned)
+ }
+ }
+ return svc
+}
+
+func (s *PaymentResumeService) isSigningConfigured() bool {
+ return s != nil && len(s.signingKey) > 0
+}
+
+func (s *PaymentResumeService) ensureSigningKey() error {
+ if s.isSigningConfigured() {
+ return nil
+ }
+ return infraerrors.ServiceUnavailable(paymentResumeNotConfiguredCode, paymentResumeNotConfiguredMessage)
+}
+
+func NormalizeVisibleMethod(method string) string {
+ return payment.GetBasePaymentType(strings.TrimSpace(method))
+}
+
+func NormalizeVisibleMethods(methods []string) []string {
+ if len(methods) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(methods))
+ out := make([]string, 0, len(methods))
+ for _, method := range methods {
+ normalized := NormalizeVisibleMethod(method)
+ if normalized == "" {
+ continue
+ }
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ out = append(out, normalized)
+ }
+ return out
+}
+
+func NormalizePaymentSource(source string) string {
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case "", PaymentSourceHostedRedirect:
+ return PaymentSourceHostedRedirect
+ case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
+ return PaymentSourceWechatInAppResume
+ default:
+ return strings.TrimSpace(strings.ToLower(source))
+ }
+}
+
+func NormalizeVisibleMethodSource(method, source string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
+ return VisibleMethodSourceOfficialAlipay
+ case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayAlipay
+ }
+ case payment.TypeWxpay:
+ switch strings.TrimSpace(strings.ToLower(source)) {
+ case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
+ return VisibleMethodSourceOfficialWechat
+ case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
+ return VisibleMethodSourceEasyPayWechat
+ }
+ }
+ return ""
+}
+
+func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
+ switch NormalizeVisibleMethodSource(method, source) {
+ case VisibleMethodSourceOfficialAlipay:
+ return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceEasyPayAlipay:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
+ case VisibleMethodSourceOfficialWechat:
+ return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ case VisibleMethodSourceEasyPayWechat:
+ return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
+ default:
+ return "", false
+ }
+}
+
+func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
+ if inner == nil || configService == nil || configService.entClient == nil {
+ return inner
+ }
+ return &visibleMethodLoadBalancer{inner: inner, configService: configService}
+}
+
+func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
+ return lb.inner.GetInstanceConfig(ctx, instanceID)
+}
+
+func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
+ visibleMethod := NormalizeVisibleMethod(paymentType)
+ if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
+ return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
+ }
+
+ inst, err := lb.configService.resolveEnabledVisibleMethodInstance(ctx, visibleMethod)
+ if err != nil {
+ return nil, err
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("visible payment method %s has no enabled provider instance", visibleMethod)
+ }
+ return lb.inner.SelectInstance(ctx, inst.ProviderKey, paymentType, strategy, orderAmount)
+}
+
+func visibleMethodEnabledSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipayEnabled
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpayEnabled
+ default:
+ return ""
+ }
+}
+
+func visibleMethodSourceSettingKey(method string) string {
+ switch NormalizeVisibleMethod(method) {
+ case payment.TypeAlipay:
+ return SettingPaymentVisibleMethodAlipaySource
+ case payment.TypeWxpay:
+ return SettingPaymentVisibleMethodWxpaySource
+ default:
+ return ""
+ }
+}
+
+func CanonicalizeReturnURL(raw string, srcHost string, srcURL string) (string, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return "", nil
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil || !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
+ }
+ if parsed.Scheme != "http" && parsed.Scheme != "https" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
+ }
+ parsed.Fragment = ""
+ if parsed.Path == "" {
+ parsed.Path = "/"
+ }
+ if parsed.Path != paymentResultReturnPath {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must target the canonical internal payment result page")
+ }
+ if !allowedReturnURLHost(parsed.Host, srcHost, srcURL) {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use the same host as the current site or browser origin")
+ }
+ return parsed.String(), nil
+}
+
+func allowedReturnURLHost(returnURLHost string, requestHost string, refererURL string) bool {
+ if sameOriginHost(returnURLHost, requestHost) {
+ return true
+ }
+
+ refererURL = strings.TrimSpace(refererURL)
+ if refererURL == "" {
+ return false
+ }
+ parsedReferer, err := url.Parse(refererURL)
+ if err != nil || parsedReferer.Host == "" {
+ return false
+ }
+ return sameOriginHost(returnURLHost, parsedReferer.Host)
+}
+
+func buildPaymentReturnURL(base string, orderID int64, outTradeNo string, resumeToken string) (string, error) {
+ canonical := strings.TrimSpace(base)
+ if canonical == "" {
+ return "", nil
+ }
+
+ parsed, err := url.Parse(canonical)
+ if err != nil {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid URL")
+ }
+ if !parsed.IsAbs() || parsed.Host == "" {
+ return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be a valid absolute URL")
+ }
+ parsed.Fragment = ""
+
+ query := parsed.Query()
+ if orderID > 0 {
+ query.Set("order_id", strconv.FormatInt(orderID, 10))
+ }
+ if strings.TrimSpace(outTradeNo) != "" {
+ query.Set("out_trade_no", strings.TrimSpace(outTradeNo))
+ }
+ if strings.TrimSpace(resumeToken) != "" {
+ query.Set("resume_token", strings.TrimSpace(resumeToken))
+ }
+ query.Set("status", "success")
+ parsed.RawQuery = query.Encode()
+
+ return parsed.String(), nil
+}
+
+func sameOriginHost(returnURLHost string, requestHost string) bool {
+ returnHost := strings.TrimSpace(returnURLHost)
+ reqHost := strings.TrimSpace(requestHost)
+ if returnHost == "" || reqHost == "" {
+ return false
+ }
+ if strings.EqualFold(returnHost, reqHost) {
+ return true
+ }
+
+ returnName, returnPort := splitHostPortDefault(returnHost)
+ reqName, reqPort := splitHostPortDefault(reqHost)
+ if returnName == "" || reqName == "" {
+ return false
+ }
+ return strings.EqualFold(returnName, reqName) && returnPort == reqPort
+}
+
+func splitHostPortDefault(raw string) (string, string) {
+ if host, port, err := net.SplitHostPort(raw); err == nil {
+ return host, port
+ }
+ return raw, ""
+}
+
+func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return "", err
+ }
+ if claims.OrderID <= 0 {
+ return "", fmt.Errorf("resume token requires order id")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ if claims.ExpiresAt == 0 {
+ claims.ExpiresAt = time.Now().Add(paymentResumeTokenTTL).Unix()
+ }
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return nil, err
+ }
+ var claims ResumeTokenClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
+ }
+ if claims.OrderID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
+ }
+ if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_RESUME_TOKEN", "resume token has expired"); err != nil {
+ return nil, err
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) CreateWeChatPaymentResumeToken(claims WeChatPaymentResumeClaims) (string, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return "", err
+ }
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return "", fmt.Errorf("wechat payment resume token requires openid")
+ }
+ if claims.IssuedAt == 0 {
+ claims.IssuedAt = time.Now().Unix()
+ }
+ if claims.ExpiresAt == 0 {
+ claims.ExpiresAt = time.Now().Add(wechatPaymentResumeTokenTTL).Unix()
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ claims.TokenType = wechatPaymentResumeTokenType
+ return s.createSignedToken(claims)
+}
+
+func (s *PaymentResumeService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) {
+ if err := s.ensureSigningKey(); err != nil {
+ return nil, err
+ }
+ var claims WeChatPaymentResumeClaims
+ if err := s.parseSignedToken(token, &claims); err != nil {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payload is invalid")
+ }
+ if claims.TokenType != wechatPaymentResumeTokenType {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token type mismatch")
+ }
+ claims.OpenID = strings.TrimSpace(claims.OpenID)
+ if claims.OpenID == "" {
+ return nil, infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+ if err := validatePaymentResumeExpiry(claims.ExpiresAt, "INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token has expired"); err != nil {
+ return nil, err
+ }
+ if normalized := NormalizeVisibleMethod(claims.PaymentType); normalized != "" {
+ claims.PaymentType = normalized
+ }
+ if claims.PaymentType == "" {
+ claims.PaymentType = payment.TypeWxpay
+ }
+ if claims.OrderType == "" {
+ claims.OrderType = payment.OrderTypeBalance
+ }
+ return &claims, nil
+}
+
+func (s *PaymentResumeService) createSignedToken(claims any) (string, error) {
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ return "", fmt.Errorf("marshal resume claims: %w", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ return encodedPayload + "." + s.sign(encodedPayload), nil
+}
+
+func (s *PaymentResumeService) parseSignedToken(token string, dest any) error {
+ parts := strings.Split(token, ".")
+ if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
+ }
+ if !s.verifySignature(parts[0], parts[1]) {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
+ }
+ payload, err := base64.RawURLEncoding.DecodeString(parts[0])
+ if err != nil {
+ return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
+ }
+ return json.Unmarshal(payload, dest)
+}
+
+func (s *PaymentResumeService) verifySignature(payload string, signature string) bool {
+ if s == nil {
+ return false
+ }
+ for _, key := range s.verifyKeys {
+ if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) {
+ return true
+ }
+ }
+ return false
+}
+
+func validatePaymentResumeExpiry(expiresAt int64, code, message string) error {
+ if expiresAt <= 0 {
+ return nil
+ }
+ if time.Now().Unix() > expiresAt {
+ return infraerrors.BadRequest(code, message)
+ }
+ return nil
+}
+
+func (s *PaymentResumeService) sign(payload string) string {
+ return signPaymentResumePayload(payload, s.signingKey)
+}
+
+func signPaymentResumePayload(payload string, key []byte) string {
+ mac := hmac.New(sha256.New, key)
+ _, _ = mac.Write([]byte(payload))
+ return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+}
diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go
new file mode 100644
index 00000000..7e0adc2d
--- /dev/null
+++ b/backend/internal/service/payment_resume_service_test.go
@@ -0,0 +1,808 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "net/url"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func TestNormalizeVisibleMethods(t *testing.T) {
+ t.Parallel()
+
+ got := NormalizeVisibleMethods([]string{
+ "alipay_direct",
+ "alipay",
+ " wxpay_direct ",
+ "wxpay",
+ "stripe",
+ })
+
+ want := []string{"alipay", "wxpay", "stripe"}
+ if len(got) != len(want) {
+ t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
+ }
+ }
+}
+
+func TestNormalizePaymentSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ expect string
+ }{
+ {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
+ {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
+ {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizePaymentSource(tt.input); got != tt.expect {
+ t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
+ }
+ })
+ }
+}
+
+func TestCanonicalizeReturnURL(t *testing.T) {
+ t.Parallel()
+
+ got, err := CanonicalizeReturnURL("https://example.com/payment/result?b=2#a", "example.com", "")
+ if err != nil {
+ t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
+ }
+ if got != "https://example.com/payment/result?b=2" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/payment/result?b=2")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("/payment/result", "example.com", ""); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject relative URLs")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsExternalHost(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://evil.example/payment/result", "app.example.com", ""); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject external hosts")
+ }
+}
+
+func TestCanonicalizeReturnURLAllowsConfiguredFrontendHost(t *testing.T) {
+ t.Parallel()
+
+ got, err := CanonicalizeReturnURL(
+ "https://app.example.com/payment/result?from=checkout",
+ "api.example.com",
+ "https://app.example.com/purchase",
+ )
+ if err != nil {
+ t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
+ }
+ if got != "https://app.example.com/payment/result?from=checkout" {
+ t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://app.example.com/payment/result?from=checkout")
+ }
+}
+
+func TestCanonicalizeReturnURLRejectsNonCanonicalPath(t *testing.T) {
+ t.Parallel()
+
+ if _, err := CanonicalizeReturnURL("https://app.example.com/orders/42", "app.example.com", ""); err == nil {
+ t.Fatal("CanonicalizeReturnURL should reject non-canonical result paths")
+ }
+}
+
+func TestBuildPaymentReturnURL(t *testing.T) {
+ t.Parallel()
+
+ got, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "sub2_42", "resume-token")
+ if err != nil {
+ t.Fatalf("buildPaymentReturnURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(got)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ if parsed.Fragment != "" {
+ t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", parsed.Fragment)
+ }
+ query := parsed.Query()
+ if query.Get("from") != "checkout" {
+ t.Fatalf("expected original query to be preserved, got %q", query.Get("from"))
+ }
+ if query.Get("order_id") != strconv.FormatInt(42, 10) {
+ t.Fatalf("order_id = %q", query.Get("order_id"))
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
+ }
+ if query.Get("resume_token") != "resume-token" {
+ t.Fatalf("resume_token = %q", query.Get("resume_token"))
+ }
+ if query.Get("status") != "success" {
+ t.Fatalf("status = %q", query.Get("status"))
+ }
+}
+
+func TestBuildPaymentReturnURLWithoutResumeTokenStillIncludesOutTradeNo(t *testing.T) {
+ t.Parallel()
+
+ got, err := buildPaymentReturnURL("https://example.com/payment/result", 42, "sub2_42", "")
+ if err != nil {
+ t.Fatalf("buildPaymentReturnURL returned error: %v", err)
+ }
+
+ parsed, err := url.Parse(got)
+ if err != nil {
+ t.Fatalf("url.Parse returned error: %v", err)
+ }
+ query := parsed.Query()
+ if query.Get("order_id") != "42" {
+ t.Fatalf("order_id = %q", query.Get("order_id"))
+ }
+ if query.Get("out_trade_no") != "sub2_42" {
+ t.Fatalf("out_trade_no = %q", query.Get("out_trade_no"))
+ }
+ if query.Get("resume_token") != "" {
+ t.Fatalf("resume_token = %q, want empty", query.Get("resume_token"))
+ }
+}
+
+func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
+ t.Parallel()
+
+ got, err := buildPaymentReturnURL("", 42, "sub2_42", "resume-token")
+ if err != nil {
+ t.Fatalf("buildPaymentReturnURL returned error: %v", err)
+ }
+ if got != "" {
+ t.Fatalf("buildPaymentReturnURL = %q, want empty string", got)
+ }
+}
+
+func TestPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ ProviderInstanceID: "19",
+ ProviderKey: "easypay",
+ PaymentType: "wxpay",
+ CanonicalReturnURL: "https://example.com/payment/result",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseToken(token)
+ if err != nil {
+ t.Fatalf("ParseToken returned error: %v", err)
+ }
+ if claims.OrderID != 42 || claims.UserID != 7 {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
+ t.Fatalf("claims provider snapshot mismatch: %+v", claims)
+ }
+ if claims.CanonicalReturnURL != "https://example.com/payment/result" {
+ t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
+ }
+}
+
+func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.CreateToken(ResumeTokenClaims{OrderID: 42})
+ if err == nil {
+ t.Fatal("CreateToken should reject missing signing key")
+ }
+}
+
+func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
+ t.Parallel()
+
+ token := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.ParseToken(token)
+ if err == nil {
+ t.Fatal("ParseToken should reject tokens when signing key is missing")
+ }
+}
+
+func TestParseTokenRejectsExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateToken(ResumeTokenClaims{
+ OrderID: 42,
+ UserID: 7,
+ IssuedAt: time.Now().Add(-25 * time.Hour).Unix(),
+ ExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
+ })
+ if err != nil {
+ t.Fatalf("CreateToken returned error: %v", err)
+ }
+
+ _, err = svc.ParseToken(token)
+ if err == nil {
+ t.Fatal("ParseToken should reject expired tokens")
+ }
+}
+
+func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ Amount: "12.50",
+ OrderType: payment.OrderTypeSubscription,
+ PlanID: 7,
+ RedirectTo: "/purchase?from=wechat",
+ Scope: "snsapi_base",
+ IssuedAt: 1234567890,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
+ t.Fatalf("claims mismatch: %+v", claims)
+ }
+ if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
+ t.Fatalf("claims payment context mismatch: %+v", claims)
+ }
+ if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
+ t.Fatalf("claims redirect/scope mismatch: %+v", claims)
+ }
+}
+
+func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"})
+ if err == nil {
+ t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
+ }
+}
+
+func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
+ t.Parallel()
+
+ token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
+ TokenType: wechatPaymentResumeTokenType,
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ })
+ svc := NewPaymentResumeService(nil)
+ _, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err == nil {
+ t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
+ }
+}
+
+func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
+ token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-123",
+ PaymentType: payment.TypeWxpay,
+ IssuedAt: time.Now().Add(-30 * time.Minute).Unix(),
+ ExpiresAt: time.Now().Add(-1 * time.Minute).Unix(),
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ _, err = svc.ParseWeChatPaymentResumeToken(token)
+ if err == nil {
+ t.Fatal("ParseWeChatPaymentResumeToken should reject expired tokens")
+ }
+}
+
+func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
+
+ token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-explicit-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
+ },
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-explicit-key" {
+ t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key")
+ }
+}
+
+func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
+
+ legacyKey := []byte("0123456789abcdef0123456789abcdef")
+ token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-legacy-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ svc := &PaymentService{
+ configService: &PaymentConfigService{
+ encryptionKey: legacyKey,
+ },
+ }
+
+ claims, err := svc.ParseWeChatPaymentResumeToken(token)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if claims.OpenID != "openid-legacy-key" {
+ t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key")
+ }
+}
+
+func TestNewConfiguredPaymentResumeServicePrefersExplicitSigningKeyAndKeepsLegacyVerificationFallback(t *testing.T) {
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key")
+
+ legacyKey := []byte("0123456789abcdef0123456789abcdef")
+ svc := newLegacyAwarePaymentResumeService(legacyKey)
+
+ explicitToken, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-explicit-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ explicitClaims, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).ParseWeChatPaymentResumeToken(explicitToken)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if explicitClaims.OpenID != "openid-explicit-key" {
+ t.Fatalf("openid = %q, want %q", explicitClaims.OpenID, "openid-explicit-key")
+ }
+
+ legacyToken, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
+ OpenID: "openid-legacy-key",
+ PaymentType: payment.TypeWxpay,
+ })
+ if err != nil {
+ t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
+ }
+
+ legacyClaims, err := svc.ParseWeChatPaymentResumeToken(legacyToken)
+ if err != nil {
+ t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
+ }
+ if legacyClaims.OpenID != "openid-legacy-key" {
+ t.Fatalf("openid = %q, want %q", legacyClaims.OpenID, "openid-legacy-key")
+ }
+}
+
+func TestNormalizeVisibleMethodSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ input string
+ want string
+ }{
+ {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
+ {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
+ {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
+ {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
+ {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
+ t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodProviderKeyForSource(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method string
+ source string
+ want string
+ ok bool
+ }{
+ {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
+ {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
+ {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
+ {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
+ {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
+ if got != tt.want || ok != tt.ok {
+ t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerUsesEnabledProviderInstance(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create alipay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != payment.TypeAlipay {
+ t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
+ }
+}
+
+func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabled(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method payment.PaymentType
+ officialName string
+ officialTypes string
+ easyPayName string
+ easyPayTypes string
+ sourceSetting string
+ wantProvider string
+ }{
+ {
+ name: "alipay uses official source",
+ method: payment.TypeAlipay,
+ officialName: "Official Alipay",
+ officialTypes: "alipay",
+ easyPayName: "EasyPay Alipay",
+ easyPayTypes: "alipay",
+ sourceSetting: VisibleMethodSourceOfficialAlipay,
+ wantProvider: payment.TypeAlipay,
+ },
+ {
+ name: "alipay uses easypay source",
+ method: payment.TypeAlipay,
+ officialName: "Official Alipay",
+ officialTypes: "alipay",
+ easyPayName: "EasyPay Alipay",
+ easyPayTypes: "alipay",
+ sourceSetting: VisibleMethodSourceEasyPayAlipay,
+ wantProvider: payment.TypeEasyPay,
+ },
+ {
+ name: "wxpay uses official source",
+ method: payment.TypeWxpay,
+ officialName: "Official WeChat",
+ officialTypes: "wxpay",
+ easyPayName: "EasyPay WeChat",
+ easyPayTypes: "wxpay",
+ sourceSetting: VisibleMethodSourceOfficialWechat,
+ wantProvider: payment.TypeWxpay,
+ },
+ {
+ name: "wxpay uses easypay source",
+ method: payment.TypeWxpay,
+ officialName: "Official WeChat",
+ officialTypes: "wxpay",
+ easyPayName: "EasyPay WeChat",
+ easyPayTypes: "wxpay",
+ sourceSetting: VisibleMethodSourceEasyPayWechat,
+ wantProvider: payment.TypeEasyPay,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ officialProviderKey := payment.TypeAlipay
+ if tt.method == payment.TypeWxpay {
+ officialProviderKey = payment.TypeWxpay
+ }
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(officialProviderKey).
+ SetName(tt.officialName).
+ SetConfig("{}").
+ SetSupportedTypes(tt.officialTypes).
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official provider: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName(tt.easyPayName).
+ SetConfig("{}").
+ SetSupportedTypes(tt.easyPayTypes).
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ visibleMethodSourceSettingKey(tt.method): tt.sourceSetting,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 12.5)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != tt.wantProvider {
+ t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, tt.wantProvider)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerPreservesLegacyCrossProviderRoutingWhenSourceMissing(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("Official Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official provider: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("EasyPay Alipay").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ visibleMethodSourceSettingKey(payment.TypeAlipay): "",
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", payment.TypeAlipay, payment.StrategyRoundRobin, 9.9)
+ if err != nil {
+ t.Fatalf("SelectInstance returned error: %v", err)
+ }
+ if inner.lastProviderKey != "" {
+ t.Fatalf("lastProviderKey = %q, want legacy cross-provider empty key", inner.lastProviderKey)
+ }
+ if inner.lastPaymentType != payment.TypeAlipay {
+ t.Fatalf("lastPaymentType = %q, want %q", inner.lastPaymentType, payment.TypeAlipay)
+ }
+}
+
+func TestVisibleMethodLoadBalancerRejectsInvalidSourceWhenMultipleProvidersEnabled(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ method payment.PaymentType
+ sourceValue string
+ wantMessage string
+ }{
+ {
+ name: "invalid wxpay source",
+ method: payment.TypeWxpay,
+ sourceValue: "stripe",
+ wantMessage: "wxpay source must be one of the supported payment providers",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+
+ officialProviderKey := payment.TypeAlipay
+ officialSupportedTypes := "alipay"
+ officialName := "Official Alipay"
+ easyPaySupportedTypes := "alipay"
+ easyPayName := "EasyPay Alipay"
+ if tt.method == payment.TypeWxpay {
+ officialProviderKey = payment.TypeWxpay
+ officialSupportedTypes = "wxpay"
+ officialName = "Official WeChat"
+ easyPaySupportedTypes = "wxpay"
+ easyPayName = "EasyPay WeChat"
+ }
+
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(officialProviderKey).
+ SetName(officialName).
+ SetConfig("{}").
+ SetSupportedTypes(officialSupportedTypes).
+ SetEnabled(true).
+ SetSortOrder(1).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create official provider: %v", err)
+ }
+
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName(easyPayName).
+ SetConfig("{}").
+ SetSupportedTypes(easyPaySupportedTypes).
+ SetEnabled(true).
+ SetSortOrder(2).
+ Save(ctx)
+ if err != nil {
+ t.Fatalf("create easypay provider: %v", err)
+ }
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: client,
+ settingRepo: &paymentConfigSettingRepoStub{
+ values: map[string]string{
+ visibleMethodSourceSettingKey(tt.method): tt.sourceValue,
+ },
+ },
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ _, err = lb.SelectInstance(ctx, "", tt.method, payment.StrategyRoundRobin, 9.9)
+ if err == nil {
+ t.Fatal("SelectInstance should reject invalid visible method source configuration")
+ }
+ if infraerrors.Reason(err) != "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE" {
+ t.Fatalf("Reason(err) = %q, want %q", infraerrors.Reason(err), "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE")
+ }
+ if infraerrors.Message(err) != tt.wantMessage {
+ t.Fatalf("Message(err) = %q, want %q", infraerrors.Message(err), tt.wantMessage)
+ }
+ })
+ }
+}
+
+func TestVisibleMethodLoadBalancerRejectsMissingEnabledVisibleMethodProvider(t *testing.T) {
+ t.Parallel()
+
+ inner := &captureLoadBalancer{}
+ configService := &PaymentConfigService{
+ entClient: newPaymentConfigServiceTestClient(t),
+ }
+ lb := newVisibleMethodLoadBalancer(inner, configService)
+
+ if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
+ t.Fatal("SelectInstance should reject when no enabled provider instance exists")
+ }
+}
+
+type captureLoadBalancer struct {
+ lastProviderKey string
+ lastPaymentType string
+}
+
+func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
+ return map[string]string{}, nil
+}
+
+func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
+ c.lastProviderKey = providerKey
+ c.lastPaymentType = paymentType
+ return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
+}
+
+func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
+ t.Helper()
+
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ t.Fatalf("marshal claims: %v", err)
+ }
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+ mac := hmac.New(sha256.New, []byte("sub2api-payment-resume"))
+ _, _ = mac.Write([]byte(encodedPayload))
+ signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ return encodedPayload + "." + signature
+}
diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go
index 6fc23f97..aa121e41 100644
--- a/backend/internal/service/payment_service.go
+++ b/backend/internal/service/payment_service.go
@@ -1,15 +1,18 @@
package service
import (
+ "bytes"
"context"
+ "encoding/hex"
"fmt"
"log/slog"
"math/rand/v2"
+ "os"
+ "strings"
"sync"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
- "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
@@ -45,6 +48,8 @@ const (
orderIDPrefix = "sub2_"
)
+const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY"
+
// --- Types ---
// generateOutTradeNo creates a unique external order ID for payment providers.
@@ -65,29 +70,39 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
- UserID int64
- Amount float64
- PaymentType string
- ClientIP string
- IsMobile bool
- SrcHost string
- SrcURL string
- OrderType string
- PlanID int64
+ UserID int64
+ Amount float64
+ PaymentType string
+ OpenID string
+ ClientIP string
+ IsMobile bool
+ IsWeChatBrowser bool
+ SrcHost string
+ SrcURL string
+ ReturnURL string
+ PaymentSource string
+ OrderType string
+ PlanID int64
}
type CreateOrderResponse struct {
- OrderID int64 `json:"order_id"`
- Amount float64 `json:"amount"`
- PayAmount float64 `json:"pay_amount"`
- FeeRate float64 `json:"fee_rate"`
- Status string `json:"status"`
- PaymentType string `json:"payment_type"`
- PayURL string `json:"pay_url,omitempty"`
- QRCode string `json:"qr_code,omitempty"`
- ClientSecret string `json:"client_secret,omitempty"`
- ExpiresAt time.Time `json:"expires_at"`
- PaymentMode string `json:"payment_mode,omitempty"`
+ OrderID int64 `json:"order_id"`
+ Amount float64 `json:"amount"`
+ PayAmount float64 `json:"pay_amount"`
+ FeeRate float64 `json:"fee_rate"`
+ Status string `json:"status"`
+ ResultType payment.CreatePaymentResultType `json:"result_type,omitempty"`
+ PaymentType string `json:"payment_type"`
+ OutTradeNo string `json:"out_trade_no,omitempty"`
+ PayURL string `json:"pay_url,omitempty"`
+ QRCode string `json:"qr_code,omitempty"`
+ ClientSecret string `json:"client_secret,omitempty"`
+ OAuth *payment.WechatOAuthInfo `json:"oauth,omitempty"`
+ JSAPI *payment.WechatJSAPIPayload `json:"jsapi,omitempty"`
+ JSAPIPayload *payment.WechatJSAPIPayload `json:"jsapi_payload,omitempty"`
+ ExpiresAt time.Time `json:"expires_at"`
+ PaymentMode string `json:"payment_mode,omitempty"`
+ ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
@@ -155,20 +170,24 @@ type TopUserStat struct {
// --- Service ---
type PaymentService struct {
- providerMu sync.Mutex
- providersLoaded bool
- entClient *dbent.Client
- registry *payment.Registry
- loadBalancer payment.LoadBalancer
- redeemService *RedeemService
- subscriptionSvc *SubscriptionService
- configService *PaymentConfigService
- userRepo UserRepository
- groupRepo GroupRepository
+ providerMu sync.Mutex
+ providersLoaded bool
+ entClient *dbent.Client
+ registry *payment.Registry
+ loadBalancer payment.LoadBalancer
+ redeemService *RedeemService
+ subscriptionSvc *SubscriptionService
+ configService *PaymentConfigService
+ userRepo UserRepository
+ groupRepo GroupRepository
+ resumeService *PaymentResumeService
+ affiliateService *AffiliateService
}
-func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
- return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
+func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
+ svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
+ svc.resumeService = psNewPaymentResumeService(configService)
+ return svc
}
// --- Provider Registry ---
@@ -219,25 +238,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) {
}
}
-// GetWebhookProvider returns the provider instance that should verify a webhook.
-// It extracts out_trade_no from the raw body, looks up the order to find the
-// original provider instance, and creates a provider with that instance's credentials.
-// Falls back to the registry provider when the order cannot be found.
-func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
- if outTradeNo != "" {
- order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
- if err == nil {
- p, pErr := s.getOrderProvider(ctx, order)
- if pErr == nil {
- return p, nil
- }
- slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr)
- }
- }
- s.EnsureProviders(ctx)
- return s.registry.GetProviderByKey(providerKey)
-}
-
// --- Helpers ---
func psIsRefundStatus(s string) bool {
@@ -262,6 +262,60 @@ func psNilIfEmpty(s string) *string {
return &s
}
+func (s *PaymentService) paymentResume() *PaymentResumeService {
+ if s.resumeService != nil {
+ return s.resumeService
+ }
+ return psNewPaymentResumeService(s.configService)
+}
+
+func NewLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
+ return newLegacyAwarePaymentResumeService(legacyKey)
+}
+
+func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService {
+ return newLegacyAwarePaymentResumeService(psResumeLegacyVerificationKey(configService))
+}
+
+func newLegacyAwarePaymentResumeService(legacyKey []byte) *PaymentResumeService {
+ signingKey, verifyFallbacks := resolvePaymentResumeSigningKeys(legacyKey)
+ return NewPaymentResumeService(signingKey, verifyFallbacks...)
+}
+
+func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte {
+ if configService == nil {
+ return nil
+ }
+ return configService.encryptionKey
+}
+
+func resolvePaymentResumeSigningKeys(legacyKey []byte) ([]byte, [][]byte) {
+ signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv))
+ if len(signingKey) == 0 {
+ if len(legacyKey) == 0 {
+ return nil, nil
+ }
+ return legacyKey, nil
+ }
+ if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) {
+ return signingKey, nil
+ }
+ return signingKey, [][]byte{legacyKey}
+}
+
+func parsePaymentResumeSigningKey(raw string) []byte {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return nil
+ }
+ if len(raw) >= 64 && len(raw)%2 == 0 {
+ if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 {
+ return decoded
+ }
+ }
+ return []byte(raw)
+}
+
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {
diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go
new file mode 100644
index 00000000..899bd7a0
--- /dev/null
+++ b/backend/internal/service/payment_visible_method_instances.go
@@ -0,0 +1,242 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+func enabledVisibleMethodsForProvider(providerKey, supportedTypes string) []string {
+ methodSet := make(map[string]struct{}, 2)
+ addMethod := func(method string) {
+ method = NormalizeVisibleMethod(method)
+ switch method {
+ case payment.TypeAlipay, payment.TypeWxpay:
+ methodSet[method] = struct{}{}
+ }
+ }
+
+ switch strings.TrimSpace(providerKey) {
+ case payment.TypeAlipay:
+ if strings.TrimSpace(supportedTypes) == "" {
+ addMethod(payment.TypeAlipay)
+ break
+ }
+ for _, supportedType := range splitTypes(supportedTypes) {
+ if NormalizeVisibleMethod(supportedType) == payment.TypeAlipay {
+ addMethod(payment.TypeAlipay)
+ break
+ }
+ }
+ case payment.TypeWxpay:
+ if strings.TrimSpace(supportedTypes) == "" {
+ addMethod(payment.TypeWxpay)
+ break
+ }
+ for _, supportedType := range splitTypes(supportedTypes) {
+ if NormalizeVisibleMethod(supportedType) == payment.TypeWxpay {
+ addMethod(payment.TypeWxpay)
+ break
+ }
+ }
+ case payment.TypeEasyPay:
+ for _, supportedType := range splitTypes(supportedTypes) {
+ addMethod(supportedType)
+ }
+ }
+
+ methods := make([]string, 0, len(methodSet))
+ for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} {
+ if _, ok := methodSet[method]; ok {
+ methods = append(methods, method)
+ }
+ }
+ return methods
+}
+
+func providerSupportsVisibleMethod(inst *dbent.PaymentProviderInstance, method string) bool {
+ if inst == nil || !inst.Enabled {
+ return false
+ }
+ method = NormalizeVisibleMethod(method)
+ for _, candidate := range enabledVisibleMethodsForProvider(inst.ProviderKey, inst.SupportedTypes) {
+ if candidate == method {
+ return true
+ }
+ }
+ return false
+}
+
+func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInstance, method string) []*dbent.PaymentProviderInstance {
+ filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
+ for _, inst := range instances {
+ if providerSupportsVisibleMethod(inst, method) {
+ filtered = append(filtered, inst)
+ }
+ }
+ return filtered
+}
+
+func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance {
+ filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances))
+ for _, inst := range instances {
+ if !providerSupportsVisibleMethod(inst, method) {
+ continue
+ }
+ if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) {
+ continue
+ }
+ filtered = append(filtered, inst)
+ }
+ return filtered
+}
+
+func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string {
+ seen := make(map[string]struct{}, len(instances))
+ keys := make([]string, 0, len(instances))
+ for _, inst := range instances {
+ if inst == nil {
+ continue
+ }
+ key := strings.TrimSpace(inst.ProviderKey)
+ if key == "" {
+ continue
+ }
+ normalized := strings.ToLower(key)
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ keys = append(keys, key)
+ }
+ return keys
+}
+
+func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance {
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return nil
+ }
+ for _, inst := range instances {
+ if strings.EqualFold(strings.TrimSpace(inst.ProviderKey), providerKey) {
+ return inst
+ }
+ }
+ return nil
+}
+
+func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts(
+ ctx context.Context,
+ excludeID int64,
+ providerKey string,
+ supportedTypes string,
+ enabled bool,
+) error {
+ // Visible methods are selected by configured source (official/easypay),
+ // so multiple enabled providers can intentionally claim the same user-facing
+ // method. Order creation and limits will route through the configured source.
+ _, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled
+ return nil
+}
+
+func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context.Context, method string) (string, error) {
+ method = NormalizeVisibleMethod(method)
+ sourceKey := visibleMethodSourceSettingKey(method)
+ rawSource := ""
+ if s != nil && s.settingRepo != nil && sourceKey != "" {
+ value, err := s.settingRepo.GetValue(ctx, sourceKey)
+ if err != nil {
+ if !errors.Is(err, ErrSettingNotFound) {
+ return "", fmt.Errorf("get %s: %w", sourceKey, err)
+ }
+ } else {
+ rawSource = value
+ }
+ }
+
+ normalizedSource, err := normalizeVisibleMethodSettingSource(method, rawSource, true)
+ if err != nil {
+ return "", err
+ }
+ if normalizedSource == "" {
+ return "", nil
+ }
+ providerKey, ok := VisibleMethodProviderKeyForSource(method, normalizedSource)
+ if !ok {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source must be one of the supported payment providers", method),
+ )
+ }
+ return providerKey, nil
+}
+
+func (s *PaymentConfigService) resolveVisibleMethodProviderKey(
+ ctx context.Context,
+ method string,
+ matching []*dbent.PaymentProviderInstance,
+) (string, error) {
+ switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) {
+ case 0:
+ return "", nil
+ case 1:
+ return strings.TrimSpace(providerKeys[0]), nil
+ default:
+ providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method)
+ if err != nil {
+ return "", err
+ }
+ if providerKey == "" {
+ return "", nil
+ }
+ selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey)
+ if selected == nil {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source has no enabled provider instance", method),
+ )
+ }
+ return strings.TrimSpace(selected.ProviderKey), nil
+ }
+}
+
+func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance(
+ ctx context.Context,
+ method string,
+) (*dbent.PaymentProviderInstance, error) {
+ if s == nil || s.entClient == nil {
+ return nil, nil
+ }
+
+ method = NormalizeVisibleMethod(method)
+ if method != payment.TypeAlipay && method != payment.TypeWxpay {
+ return nil, nil
+ }
+
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.EnabledEQ(true)).
+ Order(paymentproviderinstance.BySortOrder()).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query enabled payment providers: %w", err)
+ }
+
+ matching := filterEnabledVisibleMethodInstances(instances, method)
+ providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching)
+ if err != nil {
+ return nil, err
+ }
+ if providerKey == "" {
+ if len(matching) == 0 {
+ return nil, nil
+ }
+ return &dbent.PaymentProviderInstance{ProviderKey: ""}, nil
+ }
+ return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil
+}
diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go
new file mode 100644
index 00000000..f2da40d9
--- /dev/null
+++ b/backend/internal/service/payment_webhook_provider.go
@@ -0,0 +1,148 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+// GetWebhookProvider returns the provider instance that should verify a webhook.
+// It resolves the original provider instance from the order whenever possible and
+// only falls back to a registry provider for legacy/single-instance scenarios.
+func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) {
+ providers, err := s.GetWebhookProviders(ctx, providerKey, outTradeNo)
+ if err != nil {
+ return nil, err
+ }
+ if len(providers) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+ return providers[0], nil
+}
+
+// GetWebhookProviders returns provider candidates that can verify the webhook.
+// Official WeChat Pay may require multiple candidates because the callback body
+// cannot be bound to a merchant before decryption.
+func (s *PaymentService) GetWebhookProviders(ctx context.Context, providerKey, outTradeNo string) ([]payment.Provider, error) {
+ if outTradeNo != "" {
+ order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx)
+ if err == nil {
+ if psHasPinnedProviderInstance(order) {
+ prov, err := s.getPinnedOrderProvider(ctx, order)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ inst, err := s.getOrderProviderInstance(ctx, order)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst != nil {
+ prov, err := s.createProviderFromInstance(ctx, inst)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ if strings.TrimSpace(providerKey) == payment.TypeWxpay {
+ return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
+ }
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
+ }
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProviderByKey(providerKey)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+ }
+ }
+
+ if strings.TrimSpace(providerKey) == payment.TypeWxpay {
+ return s.getEnabledWebhookProvidersByKey(ctx, providerKey)
+ }
+
+ if !s.webhookRegistryFallbackAllowed(ctx, providerKey) {
+ return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey)
+ }
+
+ s.EnsureProviders(ctx)
+ prov, err := s.registry.GetProviderByKey(providerKey)
+ if err != nil {
+ return nil, err
+ }
+ return []payment.Provider{prov}, nil
+}
+
+func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
+ inst, err := s.getOrderProviderInstance(ctx, o)
+ if err != nil {
+ return nil, fmt.Errorf("load order provider instance: %w", err)
+ }
+ if inst == nil {
+ return nil, fmt.Errorf("order %d provider instance is missing", o.ID)
+ }
+ return s.createProviderFromInstance(ctx, inst)
+}
+
+func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool {
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" || s == nil || s.entClient == nil {
+ return false
+ }
+
+ count, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ paymentproviderinstance.EnabledEQ(true),
+ ).
+ Count(ctx)
+ if err != nil {
+ slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err)
+ return false
+ }
+ return count <= 1
+}
+
+func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool {
+ return order != nil && (psOrderProviderSnapshot(order) != nil || (order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != ""))
+}
+
+func (s *PaymentService) getEnabledWebhookProvidersByKey(ctx context.Context, providerKey string) ([]payment.Provider, error) {
+ providerKey = strings.TrimSpace(providerKey)
+ instances, err := s.entClient.PaymentProviderInstance.Query().
+ Where(
+ paymentproviderinstance.ProviderKeyEQ(providerKey),
+ paymentproviderinstance.EnabledEQ(true),
+ ).
+ Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query webhook provider instances: %w", err)
+ }
+ if len(instances) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+
+ providers := make([]payment.Provider, 0, len(instances))
+ for _, inst := range instances {
+ prov, provErr := s.createProviderFromInstance(ctx, inst)
+ if provErr != nil {
+ slog.Warn("skip webhook provider instance", "provider", providerKey, "instanceID", inst.ID, "error", provErr)
+ continue
+ }
+ providers = append(providers, prov)
+ }
+ if len(providers) == 0 {
+ return nil, payment.ErrProviderNotFound
+ }
+ return providers, nil
+}
diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go
new file mode 100644
index 00000000..0f3efa1f
--- /dev/null
+++ b/backend/internal/service/payment_webhook_provider_test.go
@@ -0,0 +1,510 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "strconv"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/stretchr/testify/require"
+)
+
+const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef"
+
+type webhookProviderTestDouble struct {
+ key string
+ types []payment.PaymentType
+}
+
+func (p webhookProviderTestDouble) Name() string { return p.key }
+func (p webhookProviderTestDouble) ProviderKey() string { return p.key }
+func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types }
+func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ panic("unexpected call")
+}
+func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
+
+func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string {
+ t.Helper()
+
+ data, err := json.Marshal(config)
+ require.NoError(t, err)
+
+ encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey))
+ require.NoError(t, err)
+ return encrypted
+}
+
+func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer {
+ return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey))
+}
+
+func encryptValidWebhookWxpayConfig(t *testing.T, suffix string) string {
+ t.Helper()
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ privDER, err := x509.MarshalPKCS8PrivateKey(key)
+ require.NoError(t, err)
+ pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+ require.NoError(t, err)
+
+ return encryptWebhookProviderConfig(t, map[string]string{
+ "appId": "wx-app-" + suffix,
+ "mchId": "mch-" + suffix,
+ "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+ "apiV3Key": webhookProviderTestEncryptionKey,
+ "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})),
+ "publicKeyId": "public-key-id-" + suffix,
+ "certSerial": "cert-serial-" + suffix,
+ })
+}
+
+func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-a").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeStripe
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeStripe,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpayDirect,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeEasyPay).
+ SetName("easypay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeWxpay,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-disabled-legacy").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(false).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-enabled-current").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeStripe
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeStripe,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-only").
+ SetConfig("{}").
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ providerKey := payment.TypeWxpay
+ order := &dbent.PaymentOrder{
+ PaymentType: payment.TypeAlipayDirect,
+ ProviderKey: &providerKey,
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.Nil(t, got)
+}
+
+func TestGetOrderProviderInstanceUsesProviderSnapshotWhenPinnedColumnMissing(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ inst, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-snapshot").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_snapshot"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 42,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(inst.ID, 10),
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ require.Equal(t, inst.ID, got.ID)
+}
+
+func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-legacy-fallback").
+ SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_legacy"})).
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ order := &dbent.PaymentOrder{
+ ID: 43,
+ PaymentType: payment.TypeStripe,
+ ProviderSnapshot: map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": "999999",
+ "provider_key": payment.TypeStripe,
+ },
+ }
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ }
+
+ got, err := svc.getOrderProviderInstance(ctx, order)
+ require.Nil(t, got)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider snapshot instance 999999 is missing")
+}
+
+func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ wxpayConfigA := encryptValidWebhookWxpayConfig(t, "a")
+ wxpayConfigB := encryptValidWebhookWxpayConfig(t, "b")
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-a").
+ SetConfig(wxpayConfigA).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-b").
+ SetConfig(wxpayConfigB).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "")
+ require.NoError(t, err)
+ require.Len(t, providers, 2)
+}
+
+func TestGetWebhookProvidersRejectAmbiguousFallbackForNonWxpay(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-a").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeAlipay).
+ SetName("alipay-b").
+ SetConfig("{}").
+ SetSupportedTypes("alipay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ _, err = svc.GetWebhookProviders(ctx, payment.TypeAlipay, "")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "ambiguous")
+}
+
+func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ _, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeStripe).
+ SetName("stripe-a").
+ SetConfig("{}").
+ SetSupportedTypes("stripe").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ registry.Register(webhookProviderTestDouble{
+ key: payment.TypeStripe,
+ types: []payment.PaymentType{payment.TypeStripe},
+ })
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeStripe, "")
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ prov := providers[0]
+ require.Equal(t, payment.TypeStripe, prov.ProviderKey())
+}
+
+func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ pinnedInstanceID := "999"
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("TEST-RECHARGE").
+ SetOutTradeNo("sub2_test_pinned_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderInstanceID(pinnedInstanceID).
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ registry.Register(webhookProviderTestDouble{
+ key: payment.TypeWxpay,
+ types: []payment.PaymentType{payment.TypeWxpay},
+ })
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ _, err = svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_pinned_order")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "provider instance")
+}
+
+func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentConfigServiceTestClient(t)
+ user, err := client.User.Create().
+ SetEmail("snapshot-webhook@example.com").
+ SetPasswordHash("hash").
+ SetUsername("snapshot-webhook").
+ Save(ctx)
+ require.NoError(t, err)
+
+ wxpayConfigA := encryptValidWebhookWxpayConfig(t, "snapshot-a")
+ wxpayConfigB := encryptValidWebhookWxpayConfig(t, "snapshot-b")
+ instA, err := client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-a").
+ SetConfig(wxpayConfigA).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.PaymentProviderInstance.Create().
+ SetProviderKey(payment.TypeWxpay).
+ SetName("wxpay-snapshot-b").
+ SetConfig(wxpayConfigB).
+ SetSupportedTypes("wxpay").
+ SetEnabled(true).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(66).
+ SetPayAmount(66).
+ SetFeeRate(0).
+ SetRechargeCode("SNAPSHOT-WEBHOOK").
+ SetOutTradeNo("sub2_test_snapshot_webhook_order").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ SetProviderSnapshot(map[string]any{
+ "schema_version": 1,
+ "provider_instance_id": strconv.FormatInt(instA.ID, 10),
+ "provider_key": payment.TypeWxpay,
+ "payment_mode": "native",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &PaymentService{
+ entClient: client,
+ loadBalancer: newWebhookProviderTestLoadBalancer(client),
+ registry: payment.NewRegistry(),
+ providersLoaded: true,
+ }
+
+ providers, err := svc.GetWebhookProviders(ctx, payment.TypeWxpay, "sub2_test_snapshot_webhook_order")
+ require.NoError(t, err)
+ require.Len(t, providers, 1)
+ require.Equal(t, payment.TypeWxpay, providers[0].ProviderKey())
+}
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index 2bf48702..91a02901 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
}
}
+ // GPT-5.5 回退到 GPT-5.4 定价
+ if strings.HasPrefix(model, "gpt-5.5") {
+ logger.With(zap.String("component", "service.pricing")).
+ Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)"))
+ return openAIGPT54FallbackPricing
+ }
+
if strings.HasPrefix(model, "gpt-5.4-mini") {
logger.With(zap.String("component", "service.pricing")).
Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)"))
@@ -812,6 +819,16 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
return openAIGPT54FallbackPricing
}
+ if isOpenAIImageGenerationModel(model) {
+ for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} {
+ if pricing, ok := s.pricingData[candidate]; ok {
+ logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate)
+ return pricing
+ }
+ }
+ return nil
+ }
+
// 最终回退到 DefaultTestModel
defaultModel := strings.ToLower(openai.DefaultTestModel)
if pricing, ok := s.pricingData[defaultModel]; ok {
diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go
index 13a5c70c..e2bd7cf3 100644
--- a/backend/internal/service/pricing_service_test.go
+++ b/backend/internal/service/pricing_service_test.go
@@ -128,6 +128,21 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t
require.Zero(t, got.LongContextInputTokenThreshold)
}
+func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) {
+ imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3}
+ textPricing := &LiteLLMModelPricing{InputCostPerToken: 9}
+
+ svc := &PricingService{
+ pricingData: map[string]*LiteLLMModelPricing{
+ "gpt-image-2": imagePricing,
+ "gpt-5.4": textPricing,
+ },
+ }
+
+ got := svc.GetModelPricing("gpt-image-3")
+ require.Same(t, imagePricing, got)
+}
+
func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) {
raw := map[string]any{
"gpt-5.4": map[string]any{
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 53581574..9344de47 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -1,8 +1,10 @@
package service
import (
+ "bytes"
"context"
"encoding/json"
+ "fmt"
"log/slog"
"net/http"
"strconv"
@@ -23,6 +25,7 @@ type RateLimitService struct {
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
+ openAI403CounterCache OpenAI403CounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex
@@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface {
const geminiPrecheckCacheTTL = time.Minute
+const (
+ openAI403CooldownMinutesDefault = 10
+ openAI403DisableThreshold = 3
+ openAI403CounterWindowMinutes = 180
+)
+
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{
@@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
s.timeoutCounterCache = cache
}
+// SetOpenAI403CounterCache 设置 OpenAI 403 连续失败计数器(可选依赖)
+func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) {
+ s.openAI403CounterCache = cache
+}
+
// SetSettingService 设置系统设置服务(可选依赖)
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
@@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
+func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string {
+ prefix = strings.TrimSpace(prefix)
+ if prefix != "" && !strings.HasSuffix(prefix, " ") {
+ prefix += " "
+ }
+
+ if msg := strings.TrimSpace(upstreamMsg); msg != "" {
+ return prefix + msg
+ }
+
+ rawBody := bytes.TrimSpace(responseBody)
+ if len(rawBody) > 0 {
+ if json.Valid(rawBody) {
+ var compact bytes.Buffer
+ if err := json.Compact(&compact, rawBody); err == nil {
+ return prefix + truncateForLog(compact.Bytes(), 512)
+ }
+ }
+ return prefix + truncateForLog(rawBody, 512)
+ }
+
+ return prefix + fallback
+}
+
// handle403 处理 403 Forbidden 错误
// Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用;
// 其他平台保持原有 SetError 行为。
@@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst
if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
}
- // 非 Antigravity 平台:保持原有行为
- msg := "Access forbidden (403): account may be suspended or lack permissions"
- if upstreamMsg != "" {
- msg = "Access forbidden (403): " + upstreamMsg
+ if account.Platform == PlatformOpenAI {
+ return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody)
}
+ // 非 Antigravity 平台:保持原有行为
+ msg := buildForbiddenErrorMessage(
+ "Access forbidden (403):",
+ upstreamMsg,
+ responseBody,
+ "account may be suspended or lack permissions",
+ )
s.handleAuthError(ctx, account, msg)
return true
}
+func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
+ msg := buildForbiddenErrorMessage(
+ "Access forbidden (403):",
+ upstreamMsg,
+ responseBody,
+ "account may be suspended or lack permissions",
+ )
+
+ if s.openAI403CounterCache == nil {
+ s.handleAuthError(ctx, account, msg)
+ return true
+ }
+
+ count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes)
+ if err != nil {
+ slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err)
+ s.handleAuthError(ctx, account, msg)
+ return true
+ }
+
+ if count >= openAI403DisableThreshold {
+ msg = fmt.Sprintf("%s | consecutive_403=%d/%d", msg, count, openAI403DisableThreshold)
+ s.handleAuthError(ctx, account, msg)
+ return true
+ }
+
+ until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
+ reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
+ if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
+ slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
+ s.handleAuthError(ctx, account, msg)
+ return true
+ }
+
+ slog.Warn(
+ "openai_403_temp_unschedulable",
+ "account_id", account.ID,
+ "until", until,
+ "count", count,
+ "threshold", openAI403DisableThreshold,
+ )
+ return true
+}
+
// handleAntigravity403 处理 Antigravity 平台的 403 错误
// validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复)
// violation(违规封号)→ 永久 SetError(需人工处理)
@@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
switch fbType {
case forbiddenTypeValidation:
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
- msg := "Validation required (403): account needs Google verification"
- if upstreamMsg != "" {
- msg = "Validation required (403): " + upstreamMsg
- }
+ msg := buildForbiddenErrorMessage(
+ "Validation required (403):",
+ upstreamMsg,
+ responseBody,
+ "account needs Google verification",
+ )
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL
}
@@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
case forbiddenTypeViolation:
// 违规封号: 永久禁用,需人工处理
- msg := "Account violation (403): terms of service violation"
- if upstreamMsg != "" {
- msg = "Account violation (403): " + upstreamMsg
- }
+ msg := buildForbiddenErrorMessage(
+ "Account violation (403):",
+ upstreamMsg,
+ responseBody,
+ "terms of service violation",
+ )
s.handleAuthError(ctx, account, msg)
return true
default:
// 通用 403: 保持原有行为
- msg := "Access forbidden (403): account may be suspended or lack permissions"
- if upstreamMsg != "" {
- msg = "Access forbidden (403): " + upstreamMsg
- }
+ msg := buildForbiddenErrorMessage(
+ "Access forbidden (403):",
+ upstreamMsg,
+ responseBody,
+ "account may be suspended or lack permissions",
+ )
s.handleAuthError(ctx, account, msg)
return true
}
@@ -838,7 +931,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
// 返回 nil 表示无法从响应头中确定重置时间
-func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
+func calculateOpenAI429ResetTime(headers http.Header) *time.Time {
snapshot := ParseCodexRateLimitHeaders(headers)
if snapshot == nil {
return nil
@@ -884,6 +977,10 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
+func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
+ return calculateOpenAI429ResetTime(headers)
+}
+
// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
type anthropic429Result struct {
resetAt time.Time // The correct reset time to use for SetRateLimited
@@ -1221,9 +1318,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
+ s.ResetOpenAI403Counter(ctx, accountID)
return nil
}
+func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) {
+ if s == nil || s.openAI403CounterCache == nil || accountID <= 0 {
+ return
+ }
+ if err := s.openAI403CounterCache.ResetOpenAI403Count(ctx, accountID); err != nil {
+ slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err)
+ }
+}
+
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
account, err := s.accountRepo.GetByID(ctx, accountID)
@@ -1250,6 +1357,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
}
result.ClearedRateLimit = true
}
+ if result.ClearedError || result.ClearedRateLimit {
+ s.ResetOpenAI403Counter(ctx, accountID)
+ }
return result, nil
}
diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go
index 9e5e2b0e..73b7849f 100644
--- a/backend/internal/service/ratelimit_service_401_test.go
+++ b/backend/internal/service/ratelimit_service_401_test.go
@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct {
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
+ lastTempReason string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
+ r.lastTempReason = reason
return nil
}
@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct {
err error
}
+type openAI403CounterCacheStub struct {
+ counts []int64
+ resetCalls []int64
+ err error
+}
+
+func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
+ if s.err != nil {
+ return 0, s.err
+ }
+ if len(s.counts) == 0 {
+ return 1, nil
+ }
+ count := s.counts[0]
+ s.counts = s.counts[1:]
+ return count, nil
+}
+
+func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
+ s.resetCalls = append(s.resetCalls, accountID)
+ return nil
+}
+
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
diff --git a/backend/internal/service/ratelimit_service_403_test.go b/backend/internal/service/ratelimit_service_403_test.go
new file mode 100644
index 00000000..2fd11b71
--- /dev/null
+++ b/backend/internal/service/ratelimit_service_403_test.go
@@ -0,0 +1,64 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ counter := &openAI403CounterCacheStub{counts: []int64{1}}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ service.SetOpenAI403CounterCache(counter)
+ account := &Account{
+ ID: 301,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ shouldDisable := service.HandleUpstreamError(
+ context.Background(),
+ account,
+ http.StatusForbidden,
+ http.Header{},
+ []byte(`{"error":{"message":"temporary edge rejection"}}`),
+ )
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 0, repo.setErrorCalls)
+ require.Equal(t, 1, repo.tempCalls)
+ require.Contains(t, repo.lastTempReason, "temporary edge rejection")
+ require.Contains(t, repo.lastTempReason, "(1/3)")
+}
+
+func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ counter := &openAI403CounterCacheStub{counts: []int64{3}}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ service.SetOpenAI403CounterCache(counter)
+ account := &Account{
+ ID: 302,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ shouldDisable := service.HandleUpstreamError(
+ context.Background(),
+ account,
+ http.StatusForbidden,
+ http.Header{},
+ []byte(`{"error":{"message":"workspace forbidden by policy"}}`),
+ )
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Equal(t, 0, repo.tempCalls)
+ require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy")
+ require.Contains(t, repo.lastErrorMsg, "consecutive_403=3/3")
+}
diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go
index 89c754c8..619bb773 100644
--- a/backend/internal/service/ratelimit_service_openai_test.go
+++ b/backend/internal/service/ratelimit_service_openai_test.go
@@ -7,6 +7,9 @@ import (
"net/http"
"testing"
"time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
)
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
@@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
}
}
+func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ account := &Account{
+ ID: 201,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ shouldDisable := service.HandleUpstreamError(
+ context.Background(),
+ account,
+ 403,
+ http.Header{},
+ []byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`),
+ )
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy")
+ require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions")
+}
+
+func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ account := &Account{
+ ID: 202,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ shouldDisable := service.HandleUpstreamError(
+ context.Background(),
+ account,
+ 403,
+ http.Header{},
+ []byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`),
+ )
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Contains(t, repo.lastErrorMsg, `"access_denied"`)
+ require.Contains(t, repo.lastErrorMsg, `"ip_blocked"`)
+ require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions")
+}
+
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes
sUsed := 60.0
diff --git a/backend/internal/service/scheduler_cache.go b/backend/internal/service/scheduler_cache.go
index f36135e0..f9794c82 100644
--- a/backend/internal/service/scheduler_cache.go
+++ b/backend/internal/service/scheduler_cache.go
@@ -59,6 +59,8 @@ type SchedulerCache interface {
UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
// TryLockBucket 尝试获取分桶重建锁。
TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error)
+ // UnlockBucket 释放分桶重建锁。
+ UnlockBucket(ctx context.Context, bucket SchedulerBucket) error
// ListBuckets 返回已注册的分桶集合。
ListBuckets(ctx context.Context) ([]SchedulerBucket, error)
// GetOutboxWatermark 读取 outbox 水位。
diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go
index 5c0b289b..0b32c2ad 100644
--- a/backend/internal/service/scheduler_snapshot_hydration_test.go
+++ b/backend/internal/service/scheduler_snapshot_hydration_test.go
@@ -44,6 +44,10 @@ func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket Sched
return true, nil
}
+func (c *snapshotHydrationCache) UnlockBucket(ctx context.Context, bucket SchedulerBucket) error {
+ return nil
+}
+
func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) {
return nil, nil
}
diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go
index 62b6993d..a68cdf0c 100644
--- a/backend/internal/service/scheduler_snapshot_service.go
+++ b/backend/internal/service/scheduler_snapshot_service.go
@@ -544,6 +544,9 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch
if !ok {
return nil
}
+ defer func() {
+ _ = s.cache.UnlockBucket(ctx, bucket)
+ }()
rebuildCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 7f4a2eb1..2bae686a 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"log/slog"
+ "math"
"net/url"
"sort"
"strconv"
@@ -81,10 +82,11 @@ const backendModeDBTimeout = 5 * time.Second
// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL)
type cachedGatewayForwardingSettings struct {
- fingerprintUnification bool
- metadataPassthrough bool
- cchSigning bool
- expiresAt int64 // unix nano
+ fingerprintUnification bool
+ metadataPassthrough bool
+ cchSigning bool
+ anthropicCacheTTL1hInjection bool
+ expiresAt int64 // unix nano
}
var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings
@@ -114,6 +116,253 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
}
+type ProviderDefaultGrantSettings struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+ GrantOnSignup bool
+ GrantOnFirstBind bool
+}
+
+type AuthSourceDefaultSettings struct {
+ Email ProviderDefaultGrantSettings
+ LinuxDo ProviderDefaultGrantSettings
+ OIDC ProviderDefaultGrantSettings
+ WeChat ProviderDefaultGrantSettings
+ ForceEmailOnThirdPartySignup bool
+}
+
+type authSourceDefaultKeySet struct {
+ balance string
+ concurrency string
+ subscriptions string
+ grantOnSignup string
+ grantOnFirstBind string
+}
+
+var (
+ emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultEmailBalance,
+ concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ }
+ linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
+ concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ }
+ oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultOIDCBalance,
+ concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ }
+ weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultWeChatBalance,
+ concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ }
+)
+
+const (
+ defaultAuthSourceBalance = 0
+ defaultAuthSourceConcurrency = 5
+ defaultWeChatConnectMode = "open"
+ defaultWeChatConnectScopes = "snsapi_login"
+ defaultWeChatConnectFrontend = "/auth/wechat/callback"
+)
+
+func normalizeWeChatConnectModeSetting(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "mp":
+ return "mp"
+ case "mobile":
+ return "mobile"
+ default:
+ return "open"
+ }
+}
+
+func defaultWeChatConnectScopeForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return "snsapi_userinfo"
+ case "mobile":
+ return ""
+ }
+ return defaultWeChatConnectScopes
+}
+
+func normalizeWeChatConnectScopeSetting(raw, mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ switch strings.TrimSpace(raw) {
+ case "snsapi_base":
+ return "snsapi_base"
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ default:
+ return defaultWeChatConnectScopeForMode(mode)
+ }
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bool, mode string) (bool, bool, bool) {
+ mode = normalizeWeChatConnectModeSetting(mode)
+ rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled]
+ rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled]
+ rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled]
+ openConfigured := hasOpen && strings.TrimSpace(rawOpen) != ""
+ mpConfigured := hasMP && strings.TrimSpace(rawMP) != ""
+ mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != ""
+
+ if openConfigured || mpConfigured || mobileConfigured {
+ openEnabled := strings.TrimSpace(rawOpen) == "true"
+ mpEnabled := strings.TrimSpace(rawMP) == "true"
+ mobileEnabled := strings.TrimSpace(rawMobile) == "true"
+ return openEnabled, mpEnabled, mobileEnabled
+ }
+
+ if !enabled {
+ return false, false, false
+ }
+ if mode == "mp" {
+ return false, true, false
+ }
+ if mode == "mobile" {
+ return false, false, true
+ }
+ return true, false, false
+}
+
+func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
+ mode = normalizeWeChatConnectModeSetting(mode)
+ switch mode {
+ case "open":
+ if openEnabled {
+ return "open"
+ }
+ case "mp":
+ if mpEnabled {
+ return "mp"
+ }
+ case "mobile":
+ if mobileEnabled {
+ return "mobile"
+ }
+ }
+ switch {
+ case openEnabled:
+ return "open"
+ case mpEnabled:
+ return "mp"
+ case mobileEnabled:
+ return "mobile"
+ default:
+ return mode
+ }
+}
+
+func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) {
+ mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode))
+ rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled]
+ rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled]
+ rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled]
+ openConfigured := hasOpen && strings.TrimSpace(rawOpen) != ""
+ mpConfigured := hasMP && strings.TrimSpace(rawMP) != ""
+ mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != ""
+
+ if openConfigured || mpConfigured || mobileConfigured {
+ openEnabled := strings.TrimSpace(rawOpen) == "true"
+ mpEnabled := strings.TrimSpace(rawMP) == "true"
+ mobileEnabled := strings.TrimSpace(rawMobile) == "true"
+ _, enabledConfigured := settings[SettingKeyWeChatConnectEnabled]
+ if !enabledConfigured &&
+ enabled &&
+ !openEnabled &&
+ !mpEnabled &&
+ !mobileEnabled &&
+ (base.OpenEnabled || base.MPEnabled || base.MobileEnabled) {
+ return base.OpenEnabled, base.MPEnabled, base.MobileEnabled
+ }
+ return openEnabled, mpEnabled, mobileEnabled
+ }
+ if !enabled {
+ return false, false, false
+ }
+ if base.OpenEnabled || base.MPEnabled || base.MobileEnabled {
+ return base.OpenEnabled, base.MPEnabled, base.MobileEnabled
+ }
+ return parseWeChatConnectCapabilitySettings(settings, enabled, mode)
+}
+
+func (s *SettingService) effectiveWeChatConnectOAuthConfig(settings map[string]string) WeChatConnectOAuthConfig {
+ base := config.WeChatConnectConfig{}
+ if s != nil && s.cfg != nil {
+ base = s.cfg.WeChat
+ }
+
+ enabled := base.Enabled
+ if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok {
+ enabled = strings.TrimSpace(raw) == "true"
+ }
+
+ legacyAppID := strings.TrimSpace(firstNonEmpty(
+ settings[SettingKeyWeChatConnectAppID],
+ base.AppID,
+ base.OpenAppID,
+ base.MPAppID,
+ base.MobileAppID,
+ ))
+ legacyAppSecret := strings.TrimSpace(firstNonEmpty(
+ settings[SettingKeyWeChatConnectAppSecret],
+ base.AppSecret,
+ base.OpenAppSecret,
+ base.MPAppSecret,
+ base.MobileAppSecret,
+ ))
+ openAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppID], base.OpenAppID, legacyAppID))
+ openAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectOpenAppSecret], base.OpenAppSecret, legacyAppSecret))
+ mpAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppID], base.MPAppID, legacyAppID))
+ mpAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMPAppSecret], base.MPAppSecret, legacyAppSecret))
+ mobileAppID := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppID], base.MobileAppID, legacyAppID))
+ mobileAppSecret := strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectMobileAppSecret], base.MobileAppSecret, legacyAppSecret))
+
+ modeRaw := firstNonEmpty(settings[SettingKeyWeChatConnectMode], base.Mode)
+ openEnabled, mpEnabled, mobileEnabled := mergeWeChatConnectCapabilitySettings(settings, base, enabled, modeRaw)
+ mode := normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled, modeRaw)
+
+ return WeChatConnectOAuthConfig{
+ Enabled: enabled,
+ LegacyAppID: legacyAppID,
+ LegacyAppSecret: legacyAppSecret,
+ OpenAppID: openAppID,
+ OpenAppSecret: openAppSecret,
+ MPAppID: mpAppID,
+ MPAppSecret: mpAppSecret,
+ MobileAppID: mobileAppID,
+ MobileAppSecret: mobileAppSecret,
+ OpenEnabled: openEnabled,
+ MPEnabled: mpEnabled,
+ MobileEnabled: mobileEnabled,
+ Mode: mode,
+ Scopes: normalizeWeChatConnectScopeSetting(firstNonEmpty(settings[SettingKeyWeChatConnectScopes], base.Scopes), mode),
+ RedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectRedirectURL], base.RedirectURL)),
+ FrontendRedirectURL: strings.TrimSpace(firstNonEmpty(settings[SettingKeyWeChatConnectFrontendRedirectURL], base.FrontendRedirectURL, defaultWeChatConnectFrontend)),
+ }
+}
+
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@@ -156,6 +405,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
keys := []string{
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
+ SettingKeyForceEmailOnThirdPartySignup,
SettingKeyRegistrationEmailSuffixWhitelist,
SettingKeyPromoCodeEnabled,
SettingKeyPasswordResetEnabled,
@@ -178,6 +428,22 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectAppID,
+ SettingKeyWeChatConnectAppSecret,
+ SettingKeyWeChatConnectOpenAppID,
+ SettingKeyWeChatConnectOpenAppSecret,
+ SettingKeyWeChatConnectMPAppID,
+ SettingKeyWeChatConnectMPAppSecret,
+ SettingKeyWeChatConnectMobileAppID,
+ SettingKeyWeChatConnectMobileAppSecret,
+ SettingKeyWeChatConnectOpenEnabled,
+ SettingKeyWeChatConnectMPEnabled,
+ SettingKeyWeChatConnectMobileEnabled,
+ SettingKeyWeChatConnectMode,
+ SettingKeyWeChatConnectScopes,
+ SettingKeyWeChatConnectRedirectURL,
+ SettingKeyWeChatConnectFrontendRedirectURL,
SettingKeyBackendModeEnabled,
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
@@ -186,6 +452,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
SettingKeyAccountQuotaNotifyEnabled,
+ SettingKeyChannelMonitorEnabled,
+ SettingKeyChannelMonitorDefaultIntervalSeconds,
+ SettingKeyAvailableChannelsEnabled,
+ SettingKeyAffiliateEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -212,6 +482,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
+ weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -232,6 +503,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
PasswordResetEnabled: passwordResetEnabled,
@@ -254,6 +526,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
+ WeChatOAuthEnabled: weChatEnabled,
+ WeChatOAuthOpenEnabled: weChatOpenEnabled,
+ WeChatOAuthMPEnabled: weChatMPEnabled,
+ WeChatOAuthMobileEnabled: weChatMobileEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
@@ -262,9 +538,90 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL],
+
+ ChannelMonitorEnabled: !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]),
+ ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),
+
+ AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",
+
+ AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",
}, nil
}
+// channelMonitorIntervalMin / channelMonitorIntervalMax bound the default interval
+// (mirrors the monitor-level constraint but lives here so setting_service stays decoupled).
+const (
+ channelMonitorIntervalMin = 15
+ channelMonitorIntervalMax = 3600
+ channelMonitorIntervalFallback = 60
+)
+
+// parseChannelMonitorInterval parses the stored string and clamps to [15, 3600].
+// Empty / invalid input falls back to channelMonitorIntervalFallback.
+func parseChannelMonitorInterval(raw string) int {
+ v, err := strconv.Atoi(strings.TrimSpace(raw))
+ if err != nil {
+ return channelMonitorIntervalFallback
+ }
+ return clampChannelMonitorInterval(v)
+}
+
+// clampChannelMonitorInterval clamps v to the allowed range. 0 means "not provided".
+func clampChannelMonitorInterval(v int) int {
+ if v <= 0 {
+ return 0
+ }
+ if v < channelMonitorIntervalMin {
+ return channelMonitorIntervalMin
+ }
+ if v > channelMonitorIntervalMax {
+ return channelMonitorIntervalMax
+ }
+ return v
+}
+
+// ChannelMonitorRuntime is the lightweight view of the channel monitor feature
+// consumed by the runner and user-facing handlers.
+type ChannelMonitorRuntime struct {
+ Enabled bool
+ DefaultIntervalSeconds int
+}
+
+// GetChannelMonitorRuntime reads the channel monitor feature flags directly from
+// the settings store. Fail-open: on error returns Enabled=true with the default interval.
+func (s *SettingService) GetChannelMonitorRuntime(ctx context.Context) ChannelMonitorRuntime {
+ vals, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingKeyChannelMonitorEnabled,
+ SettingKeyChannelMonitorDefaultIntervalSeconds,
+ })
+ if err != nil {
+ return ChannelMonitorRuntime{Enabled: true, DefaultIntervalSeconds: channelMonitorIntervalFallback}
+ }
+ return ChannelMonitorRuntime{
+ Enabled: !isFalseSettingValue(vals[SettingKeyChannelMonitorEnabled]),
+ DefaultIntervalSeconds: parseChannelMonitorInterval(vals[SettingKeyChannelMonitorDefaultIntervalSeconds]),
+ }
+}
+
+// AvailableChannelsRuntime is the lightweight view of the available-channels feature
+// switch consumed by the user-facing handler.
+type AvailableChannelsRuntime struct {
+ Enabled bool
+}
+
+// GetAvailableChannelsRuntime reads the available-channels feature switch directly
+// from the settings store. Fail-closed: on error returns Enabled=false, matching
+// the opt-in default (unknown ↔ disabled).
+func (s *SettingService) GetAvailableChannelsRuntime(ctx context.Context) AvailableChannelsRuntime {
+ vals, err := s.settingRepo.GetMultiple(ctx, []string{SettingKeyAvailableChannelsEnabled})
+ if err != nil {
+ return AvailableChannelsRuntime{Enabled: false}
+ }
+ return AvailableChannelsRuntime{
+ Enabled: vals[SettingKeyAvailableChannelsEnabled] == "true",
+ }
+}
+
// SetOnUpdateCallback sets a callback function to be called when settings are updated
// This is used for cache invalidation (e.g., HTML cache in frontend server)
func (s *SettingService) SetOnUpdateCallback(callback func()) {
@@ -276,50 +633,76 @@ func (s *SettingService) SetVersion(version string) {
s.version = version
}
-// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection
-// This implements the web.PublicSettingsProvider interface
+// PublicSettingsInjectionPayload is the JSON shape embedded into HTML as
+// `window.__APP_CONFIG__` so the frontend can hydrate feature flags & site
+// config before the first XHR finishes.
+//
+// INVARIANT: every `json` tag here MUST also exist on handler/dto.PublicSettings.
+// If you forget a feature-flag field here, the frontend's
+// `cachedPublicSettings.xxx_enabled` will be `undefined` on refresh until the
+// async `/api/v1/settings/public` call returns — which causes opt-in menus
+// (strict `=== true`) to flicker off/on. See
+// frontend/src/utils/featureFlags.ts for the matching registry.
+//
+// A unit test diffs this struct's JSON keys against dto.PublicSettings to catch
+// drift automatically (see setting_service_injection_test.go).
+type PublicSettingsInjectionPayload struct {
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ InvitationCodeEnabled bool `json:"invitation_code_enabled"`
+ TotpEnabled bool `json:"totp_enabled"`
+ TurnstileEnabled bool `json:"turnstile_enabled"`
+ TurnstileSiteKey string `json:"turnstile_site_key"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo"`
+ SiteSubtitle string `json:"site_subtitle"`
+ APIBaseURL string `json:"api_base_url"`
+ ContactInfo string `json:"contact_info"`
+ DocURL string `json:"doc_url"`
+ HomeContent string `json:"home_content"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
+ PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
+ TableDefaultPageSize int `json:"table_default_page_size"`
+ TablePageSizeOptions []int `json:"table_page_size_options"`
+ CustomMenuItems json.RawMessage `json:"custom_menu_items"`
+ CustomEndpoints json.RawMessage `json:"custom_endpoints"`
+ LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
+ OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
+ OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
+ BackendModeEnabled bool `json:"backend_mode_enabled"`
+ PaymentEnabled bool `json:"payment_enabled"`
+ Version string `json:"version"`
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
+
+ // Feature flags — MUST match the opt-in/opt-out registry in
+ // frontend/src/utils/featureFlags.ts. Missing a field here is the bug
+ // that hid the "可用渠道" menu on page refresh.
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+ AffiliateEnabled bool `json:"affiliate_enabled"`
+}
+
+// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection.
+// This implements the web.PublicSettingsProvider interface.
func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
settings, err := s.GetPublicSettings(ctx)
if err != nil {
return nil, err
}
- // Return a struct that matches the frontend's expected format
- return &struct {
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
- PasswordResetEnabled bool `json:"password_reset_enabled"`
- InvitationCodeEnabled bool `json:"invitation_code_enabled"`
- TotpEnabled bool `json:"totp_enabled"`
- TurnstileEnabled bool `json:"turnstile_enabled"`
- TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo,omitempty"`
- SiteSubtitle string `json:"site_subtitle,omitempty"`
- APIBaseURL string `json:"api_base_url,omitempty"`
- ContactInfo string `json:"contact_info,omitempty"`
- DocURL string `json:"doc_url,omitempty"`
- HomeContent string `json:"home_content,omitempty"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
- PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
- PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
- TableDefaultPageSize int `json:"table_default_page_size"`
- TablePageSizeOptions []int `json:"table_page_size_options"`
- CustomMenuItems json.RawMessage `json:"custom_menu_items"`
- CustomEndpoints json.RawMessage `json:"custom_endpoints"`
- LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
- BackendModeEnabled bool `json:"backend_mode_enabled"`
- PaymentEnabled bool `json:"payment_enabled"`
- OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
- OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
- Version string `json:"version,omitempty"`
- BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
- AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
- BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
- BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
- }{
+ return &PublicSettingsInjectionPayload{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
@@ -344,18 +727,85 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
- BackendModeEnabled: settings.BackendModeEnabled,
- PaymentEnabled: settings.PaymentEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
+ WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
+ WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
+ WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
+ BackendModeEnabled: settings.BackendModeEnabled,
+ PaymentEnabled: settings.PaymentEnabled,
Version: s.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
+ AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+ AffiliateEnabled: settings.AffiliateEnabled,
}, nil
}
+func DefaultWeChatConnectScopesForMode(mode string) string {
+ return defaultWeChatConnectScopeForMode(mode)
+}
+
+func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) {
+ cfg := s.effectiveWeChatConnectOAuthConfig(settings)
+
+ if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) {
+ return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+ }
+ if cfg.OpenEnabled {
+ if cfg.AppIDForMode("open") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app id not configured")
+ }
+ if cfg.AppSecretForMode("open") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app secret not configured")
+ }
+ }
+ if cfg.MPEnabled {
+ if cfg.AppIDForMode("mp") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app id not configured")
+ }
+ if cfg.AppSecretForMode("mp") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app secret not configured")
+ }
+ }
+ if cfg.MobileEnabled {
+ if cfg.AppIDForMode("mobile") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app id not configured")
+ }
+ if cfg.AppSecretForMode("mobile") == "" {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured")
+ }
+ }
+ if v := strings.TrimSpace(cfg.RedirectURL); v != "" {
+ if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid")
+ }
+ }
+ if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil {
+ return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid")
+ }
+ return cfg, nil
+}
+
+func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) {
+ cfg := s.effectiveWeChatConnectOAuthConfig(settings)
+ if !cfg.Enabled {
+ return false, false, false, false
+ }
+
+ openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != ""
+ mpReady := cfg.MPEnabled && cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != ""
+ mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != ""
+
+ return openReady || mpReady, openReady, mpReady, mobileReady
+}
+
// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
// array string, returning only items with visibility != "admin".
func filterUserVisibleMenuItems(raw string) json.RawMessage {
@@ -478,19 +928,130 @@ func parseCustomMenuItemURLs(raw string) []string {
return urls
}
+func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
+ if base.UsePKCEExplicit {
+ return base.UsePKCE
+ }
+ return true
+}
+
+func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
+ if base.ValidateIDTokenExplicit {
+ return base.ValidateIDToken
+ }
+ return true
+}
+
+func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool {
+ if configured {
+ return strings.TrimSpace(raw) == "true"
+ }
+ if explicit {
+ return explicitValue
+ }
+ return false
+}
+
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
- if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ updates, err := s.buildSystemSettingsUpdates(ctx, settings)
+ if err != nil {
return err
}
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ s.refreshCachedSettings(settings)
+ }
+ return err
+}
+
+func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, bool, error) {
+ rawSettings, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingKeyOIDCConnectUsePKCE,
+ SettingKeyOIDCConnectValidateIDToken,
+ })
+ if err != nil {
+ return false, false, fmt.Errorf("get oidc security write defaults: %w", err)
+ }
+
+ base := config.OIDCConnectConfig{}
+ if s != nil && s.cfg != nil {
+ base = s.cfg.OIDC
+ }
+
+ rawUsePKCE, hasUsePKCE := rawSettings[SettingKeyOIDCConnectUsePKCE]
+ rawValidateIDToken, hasValidateIDToken := rawSettings[SettingKeyOIDCConnectValidateIDToken]
+
+ return oidcCompatibilityWriteDefault(base, hasUsePKCE, rawUsePKCE, base.UsePKCEExplicit, base.UsePKCE),
+ oidcCompatibilityWriteDefault(base, hasValidateIDToken, rawValidateIDToken, base.ValidateIDTokenExplicit, base.ValidateIDToken),
+ nil
+}
+
+// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
+func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
+ updates, err := s.buildSystemSettingsUpdates(ctx, settings)
+ if err != nil {
+ return err
+ }
+
+ authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults)
+ if err != nil {
+ return err
+ }
+ for key, value := range authSourceUpdates {
+ updates[key] = value
+ }
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ s.refreshCachedSettings(settings)
+ }
+ return err
+}
+
+func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) {
+ if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ return nil, err
+ }
normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
- return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
+ return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
}
if normalizedWhitelist == nil {
normalizedWhitelist = []string{}
}
settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
+ alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled)
+ if err != nil {
+ return nil, err
+ }
+ wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled)
+ if err != nil {
+ return nil, err
+ }
+ settings.PaymentVisibleMethodAlipaySource = alipaySource
+ settings.PaymentVisibleMethodWxpaySource = wxpaySource
+ settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID)
+ settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret)
+ settings.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppID, settings.WeChatConnectAppID))
+ settings.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppSecret, settings.WeChatConnectAppSecret))
+ settings.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppID, settings.WeChatConnectAppID))
+ settings.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppSecret, settings.WeChatConnectAppSecret))
+ settings.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppID, settings.WeChatConnectAppID))
+ settings.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppSecret, settings.WeChatConnectAppSecret))
+ settings.WeChatConnectMode = normalizeWeChatConnectStoredMode(
+ settings.WeChatConnectOpenEnabled,
+ settings.WeChatConnectMPEnabled,
+ settings.WeChatConnectMobileEnabled,
+ settings.WeChatConnectMode,
+ )
+ settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode)
+ settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL)
+ settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL)
+ if settings.WeChatConnectFrontendRedirectURL == "" {
+ settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend
+ }
updates := make(map[string]string)
@@ -499,7 +1060,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist)
if err != nil {
- return fmt.Errorf("marshal registration email suffix whitelist: %w", err)
+ return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err)
}
updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON)
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
@@ -560,6 +1121,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
+ // WeChat Connect OAuth 登录
+ updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled)
+ updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID
+ updates[SettingKeyWeChatConnectOpenAppID] = settings.WeChatConnectOpenAppID
+ updates[SettingKeyWeChatConnectMPAppID] = settings.WeChatConnectMPAppID
+ updates[SettingKeyWeChatConnectMobileAppID] = settings.WeChatConnectMobileAppID
+ updates[SettingKeyWeChatConnectOpenEnabled] = strconv.FormatBool(settings.WeChatConnectOpenEnabled)
+ updates[SettingKeyWeChatConnectMPEnabled] = strconv.FormatBool(settings.WeChatConnectMPEnabled)
+ updates[SettingKeyWeChatConnectMobileEnabled] = strconv.FormatBool(settings.WeChatConnectMobileEnabled)
+ updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode
+ updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes
+ updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL
+ updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL
+ if settings.WeChatConnectAppSecret != "" {
+ updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret
+ }
+ if settings.WeChatConnectOpenAppSecret != "" {
+ updates[SettingKeyWeChatConnectOpenAppSecret] = settings.WeChatConnectOpenAppSecret
+ }
+ if settings.WeChatConnectMPAppSecret != "" {
+ updates[SettingKeyWeChatConnectMPAppSecret] = settings.WeChatConnectMPAppSecret
+ }
+ if settings.WeChatConnectMobileAppSecret != "" {
+ updates[SettingKeyWeChatConnectMobileAppSecret] = settings.WeChatConnectMobileAppSecret
+ }
+
// OEM设置
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
@@ -578,7 +1165,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize)
tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions)
if err != nil {
- return fmt.Errorf("marshal table page size options: %w", err)
+ return nil, fmt.Errorf("marshal table page size options: %w", err)
}
updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON)
updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
@@ -587,9 +1174,30 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
+ settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate)
+ updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64)
+ if settings.AffiliateRebateFreezeHours < 0 {
+ settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault
+ }
+ if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax {
+ settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax
+ }
+ updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours)
+ if settings.AffiliateRebateDurationDays < 0 {
+ settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault
+ }
+ if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax {
+ settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax
+ }
+ updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays)
+ if settings.AffiliateRebatePerInviteeCap < 0 {
+ settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault
+ }
+ updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64)
+ updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit)
defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
if err != nil {
- return fmt.Errorf("marshal default subscriptions: %w", err)
+ return nil, fmt.Errorf("marshal default subscriptions: %w", err)
}
updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
@@ -612,6 +1220,18 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
}
+ // Channel monitor feature switch
+ updates[SettingKeyChannelMonitorEnabled] = strconv.FormatBool(settings.ChannelMonitorEnabled)
+ if v := clampChannelMonitorInterval(settings.ChannelMonitorDefaultIntervalSeconds); v > 0 {
+ updates[SettingKeyChannelMonitorDefaultIntervalSeconds] = strconv.Itoa(v)
+ }
+
+ // Available channels feature switch
+ updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled)
+
+ // Affiliate (邀请返利) feature switch
+ updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled)
+
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion
@@ -626,6 +1246,12 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
+ updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection)
+ updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
+ updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
+ updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
+ updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
+ updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
@@ -634,32 +1260,67 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
- err = s.settingRepo.SetMultiple(ctx, updates)
- if err == nil {
- // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
- versionBoundsSF.Forget("version_bounds")
- versionBoundsCache.Store(&cachedVersionBounds{
- min: settings.MinClaudeCodeVersion,
- max: settings.MaxClaudeCodeVersion,
- expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
- })
- backendModeSF.Forget("backend_mode")
- backendModeCache.Store(&cachedBackendMode{
- value: settings.BackendModeEnabled,
- expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
- })
- gatewayForwardingSF.Forget("gateway_forwarding")
- gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: settings.EnableFingerprintUnification,
- metadataPassthrough: settings.EnableMetadataPassthrough,
- cchSigning: settings.EnableCCHSigning,
- expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
- })
- if s.onUpdate != nil {
- s.onUpdate() // Invalidate cache after settings update
+ return updates, nil
+}
+
+func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
+ if settings == nil {
+ return nil, nil
+ }
+
+ for _, subscriptions := range [][]DefaultSubscriptionSetting{
+ settings.Email.Subscriptions,
+ settings.LinuxDo.Subscriptions,
+ settings.OIDC.Subscriptions,
+ settings.WeChat.Subscriptions,
+ } {
+ if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
+ return nil, err
}
}
- return err
+
+ updates := make(map[string]string, 21)
+ writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
+ writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
+ writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
+ writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
+ updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
+ return updates, nil
+}
+
+func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
+ if settings == nil {
+ return
+ }
+
+ // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
+ versionBoundsSF.Forget("version_bounds")
+ versionBoundsCache.Store(&cachedVersionBounds{
+ min: settings.MinClaudeCodeVersion,
+ max: settings.MaxClaudeCodeVersion,
+ expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(),
+ })
+ backendModeSF.Forget("backend_mode")
+ backendModeCache.Store(&cachedBackendMode{
+ value: settings.BackendModeEnabled,
+ expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
+ })
+ gatewayForwardingSF.Forget("gateway_forwarding")
+ gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
+ fingerprintUnification: settings.EnableFingerprintUnification,
+ metadataPassthrough: settings.EnableMetadataPassthrough,
+ cchSigning: settings.EnableCCHSigning,
+ anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection,
+ expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
+ })
+ openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey)
+ openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
+ enabled: settings.OpenAIAdvancedSchedulerEnabled,
+ expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
+ })
+ if s.onUpdate != nil {
+ s.onUpdate() // Invalidate cache after settings update
+ }
}
func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
@@ -757,22 +1418,30 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
return false
}
-// GetGatewayForwardingSettings returns cached gateway forwarding settings.
-// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
-// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
-func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
+type gatewayForwardingSettingsResult struct {
+ fp, mp, cch, cacheTTL1h bool
+}
+
+func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
- return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning
+ return gatewayForwardingSettingsResult{
+ fp: cached.fingerprintUnification,
+ mp: cached.metadataPassthrough,
+ cch: cached.cchSigning,
+ cacheTTL1h: cached.anthropicCacheTTL1hInjection,
+ }
}
}
- type gwfResult struct {
- fp, mp, cch bool
- }
val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
- return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil
+ return gatewayForwardingSettingsResult{
+ fp: cached.fingerprintUnification,
+ mp: cached.metadataPassthrough,
+ cch: cached.cchSigning,
+ cacheTTL1h: cached.anthropicCacheTTL1hInjection,
+ }, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout)
@@ -781,16 +1450,18 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
SettingKeyEnableFingerprintUnification,
SettingKeyEnableMetadataPassthrough,
SettingKeyEnableCCHSigning,
+ SettingKeyEnableAnthropicCacheTTL1hInjection,
})
if err != nil {
slog.Warn("failed to get gateway forwarding settings", "error", err)
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: true,
- metadataPassthrough: false,
- cchSigning: false,
- expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
+ fingerprintUnification: true,
+ metadataPassthrough: false,
+ cchSigning: false,
+ anthropicCacheTTL1hInjection: false,
+ expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
})
- return gwfResult{true, false, false}, nil
+ return gatewayForwardingSettingsResult{fp: true}, nil
}
fp := true
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
@@ -798,18 +1469,33 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
}
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
cch := values[SettingKeyEnableCCHSigning] == "true"
+ cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
- fingerprintUnification: fp,
- metadataPassthrough: mp,
- cchSigning: cch,
- expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
+ fingerprintUnification: fp,
+ metadataPassthrough: mp,
+ cchSigning: cch,
+ anthropicCacheTTL1hInjection: cacheTTL1h,
+ expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
- return gwfResult{fp, mp, cch}, nil
+ return gatewayForwardingSettingsResult{fp: fp, mp: mp, cch: cch, cacheTTL1h: cacheTTL1h}, nil
})
- if r, ok := val.(gwfResult); ok {
- return r.fp, r.mp, r.cch
+ if r, ok := val.(gatewayForwardingSettingsResult); ok {
+ return r
}
- return true, false, false // fail-open defaults
+ return gatewayForwardingSettingsResult{fp: true}
+}
+
+// GetGatewayForwardingSettings returns cached gateway forwarding settings.
+// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
+// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
+func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
+ result := s.getGatewayForwardingSettingsCached(ctx)
+ return result.fp, result.mp, result.cch
+}
+
+// IsAnthropicCacheTTL1hInjectionEnabled 检查是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl。
+func (s *SettingService) IsAnthropicCacheTTL1hInjectionEnabled(ctx context.Context) bool {
+ return s.getGatewayForwardingSettingsCached(ctx).cacheTTL1h
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
@@ -848,6 +1534,78 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool {
return value == "true"
}
+// IsAffiliateEnabled 检查是否启用邀请返利功能(总开关)
+func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool {
+ value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled)
+ if err != nil {
+ return false // 默认关闭
+ }
+ return value == "true"
+}
+
+// GetAffiliateRebateRatePercent 读取并 clamp 全局返利比例。
+// 解析失败、缺失或越界都回退到 AffiliateRebateRateDefault — 该比例从不抛错,
+// 调用方只关心一个可用的数值。
+func (s *SettingService) GetAffiliateRebateRatePercent(ctx context.Context) float64 {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateRate)
+ if err != nil {
+ return AffiliateRebateRateDefault
+ }
+ rate, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
+ if err != nil || math.IsNaN(rate) || math.IsInf(rate, 0) {
+ return AffiliateRebateRateDefault
+ }
+ return clampAffiliateRebateRate(rate)
+}
+
+// GetAffiliateRebateFreezeHours 返回返利冻结期(小时)。
+// 返回 0 表示不冻结(向后兼容)。
+func (s *SettingService) GetAffiliateRebateFreezeHours(ctx context.Context) int {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateFreezeHours)
+ if err != nil {
+ return AffiliateRebateFreezeHoursDefault
+ }
+ hours, err := strconv.Atoi(strings.TrimSpace(raw))
+ if err != nil || hours < 0 {
+ return AffiliateRebateFreezeHoursDefault
+ }
+ if hours > AffiliateRebateFreezeHoursMax {
+ return AffiliateRebateFreezeHoursMax
+ }
+ return hours
+}
+
+// GetAffiliateRebateDurationDays 返回返利有效期(天)。
+// 返回 0 表示永久有效。
+func (s *SettingService) GetAffiliateRebateDurationDays(ctx context.Context) int {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebateDurationDays)
+ if err != nil {
+ return AffiliateRebateDurationDaysDefault
+ }
+ days, err := strconv.Atoi(strings.TrimSpace(raw))
+ if err != nil || days < 0 {
+ return AffiliateRebateDurationDaysDefault
+ }
+ if days > AffiliateRebateDurationDaysMax {
+ return AffiliateRebateDurationDaysMax
+ }
+ return days
+}
+
+// GetAffiliateRebatePerInviteeCap 返回单人返利上限。
+// 返回 0 表示无上限。
+func (s *SettingService) GetAffiliateRebatePerInviteeCap(ctx context.Context) float64 {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateRebatePerInviteeCap)
+ if err != nil {
+ return AffiliateRebatePerInviteeCapDefault
+ }
+ cap, err := strconv.ParseFloat(strings.TrimSpace(raw), 64)
+ if err != nil || cap < 0 || math.IsNaN(cap) || math.IsInf(cap, 0) {
+ return AffiliateRebatePerInviteeCapDefault
+ }
+ return cap
+}
+
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
@@ -910,6 +1668,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance
}
+// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。
+func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int {
+ value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit)
+ if err != nil || value == "" {
+ return 0
+ }
+ if v, err := strconv.Atoi(value); err == nil && v >= 0 {
+ return v
+ }
+ return 0
+}
+
// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
@@ -919,6 +1689,88 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value)
}
+func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
+ keys := []string{
+ SettingKeyAuthSourceDefaultEmailBalance,
+ SettingKeyAuthSourceDefaultEmailConcurrency,
+ SettingKeyAuthSourceDefaultEmailSubscriptions,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup,
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultLinuxDoBalance,
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency,
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultOIDCBalance,
+ SettingKeyAuthSourceDefaultOIDCConcurrency,
+ SettingKeyAuthSourceDefaultOIDCSubscriptions,
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultWeChatBalance,
+ SettingKeyAuthSourceDefaultWeChatConcurrency,
+ SettingKeyAuthSourceDefaultWeChatSubscriptions,
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
+ SettingKeyForceEmailOnThirdPartySignup,
+ }
+
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return nil, fmt.Errorf("get auth source default settings: %w", err)
+ }
+
+ return &AuthSourceDefaultSettings{
+ Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
+ LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
+ OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
+ WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
+ ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
+ }, nil
+}
+
+func (s *SettingService) ResolveAuthSourceGrantSettings(ctx context.Context, signupSource string, firstBind bool) (ProviderDefaultGrantSettings, bool, error) {
+ result := ProviderDefaultGrantSettings{
+ Balance: s.GetDefaultBalance(ctx),
+ Concurrency: s.GetDefaultConcurrency(ctx),
+ Subscriptions: s.GetDefaultSubscriptions(ctx),
+ }
+
+ defaults, err := s.GetAuthSourceDefaultSettings(ctx)
+ if err != nil {
+ return result, false, err
+ }
+
+ providerDefaults, ok := authSourceSignupSettings(defaults, signupSource)
+ if !ok {
+ return result, false, nil
+ }
+
+ enabled := providerDefaults.GrantOnSignup
+ if firstBind {
+ enabled = providerDefaults.GrantOnFirstBind
+ }
+ if !enabled {
+ return result, false, nil
+ }
+
+ return mergeProviderDefaultGrantSettings(result, providerDefaults), true, nil
+}
+
+func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
+ updates, err := s.buildAuthSourceDefaultUpdates(ctx, settings)
+ if err != nil {
+ return err
+ }
+ if len(updates) == 0 {
+ return nil
+ }
+
+ if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
+ return fmt.Errorf("update auth source default settings: %w", err)
+ }
+ return nil
+}
+
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -931,27 +1783,100 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
return fmt.Errorf("check existing settings: %w", err)
}
+ oidcUsePKCEDefault := true
+ oidcValidateIDTokenDefault := true
+ if s != nil && s.cfg != nil {
+ if s.cfg.OIDC.UsePKCEExplicit {
+ oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
+ }
+ if s.cfg.OIDC.ValidateIDTokenExplicit {
+ oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
+ }
+ }
+
// 初始化默认设置
defaults := map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyEmailVerifyEnabled: "false",
- SettingKeyRegistrationEmailSuffixWhitelist: "[]",
- SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
- SettingKeySiteName: "Sub2API",
- SettingKeySiteLogo: "",
- SettingKeyPurchaseSubscriptionEnabled: "false",
- SettingKeyPurchaseSubscriptionURL: "",
- SettingKeyTableDefaultPageSize: "20",
- SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- SettingKeyCustomMenuItems: "[]",
- SettingKeyCustomEndpoints: "[]",
- SettingKeyOIDCConnectEnabled: "false",
- SettingKeyOIDCConnectProviderName: "OIDC",
- SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
- SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
- SettingKeyDefaultSubscriptions: "[]",
- SettingKeySMTPPort: "587",
- SettingKeySMTPUseTLS: "false",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
+ SettingKeySiteName: "Sub2API",
+ SettingKeySiteLogo: "",
+ SettingKeyPurchaseSubscriptionEnabled: "false",
+ SettingKeyPurchaseSubscriptionURL: "",
+ SettingKeyTableDefaultPageSize: "20",
+ SettingKeyTablePageSizeOptions: "[10,20,50,100]",
+ SettingKeyCustomMenuItems: "[]",
+ SettingKeyCustomEndpoints: "[]",
+ SettingKeyWeChatConnectEnabled: "false",
+ SettingKeyWeChatConnectAppID: "",
+ SettingKeyWeChatConnectAppSecret: "",
+ SettingKeyWeChatConnectOpenAppID: "",
+ SettingKeyWeChatConnectOpenAppSecret: "",
+ SettingKeyWeChatConnectMPAppID: "",
+ SettingKeyWeChatConnectMPAppSecret: "",
+ SettingKeyWeChatConnectMobileAppID: "",
+ SettingKeyWeChatConnectMobileAppSecret: "",
+ SettingKeyWeChatConnectOpenEnabled: "false",
+ SettingKeyWeChatConnectMPEnabled: "false",
+ SettingKeyWeChatConnectMobileEnabled: "false",
+ SettingKeyWeChatConnectMode: "open",
+ SettingKeyWeChatConnectScopes: "snsapi_login",
+ SettingKeyWeChatConnectRedirectURL: "",
+ SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
+ SettingKeyOIDCConnectEnabled: "false",
+ SettingKeyOIDCConnectProviderName: "OIDC",
+ SettingKeyOIDCConnectClientID: "",
+ SettingKeyOIDCConnectClientSecret: "",
+ SettingKeyOIDCConnectIssuerURL: "",
+ SettingKeyOIDCConnectDiscoveryURL: "",
+ SettingKeyOIDCConnectAuthorizeURL: "",
+ SettingKeyOIDCConnectTokenURL: "",
+ SettingKeyOIDCConnectUserInfoURL: "",
+ SettingKeyOIDCConnectJWKSURL: "",
+ SettingKeyOIDCConnectScopes: "openid email profile",
+ SettingKeyOIDCConnectRedirectURL: "",
+ SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
+ SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
+ SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
+ SettingKeyOIDCConnectClockSkewSeconds: "120",
+ SettingKeyOIDCConnectRequireEmailVerified: "false",
+ SettingKeyOIDCConnectUserInfoEmailPath: "",
+ SettingKeyOIDCConnectUserInfoIDPath: "",
+ SettingKeyOIDCConnectUserInfoUsernamePath: "",
+ SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
+ SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
+ SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
+ SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
+ SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
+ SettingKeyDefaultUserRPMLimit: "0",
+ SettingKeyDefaultSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailBalance: "0",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultOIDCBalance: "0",
+ SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
+ SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultWeChatBalance: "0",
+ SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
+ SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
+ SettingKeyForceEmailOnThirdPartySignup: "false",
+ SettingKeySMTPPort: "587",
+ SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -968,12 +1893,28 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOpsQueryModeDefault: "auto",
SettingKeyOpsMetricsIntervalSeconds: "60",
+ // Channel monitor defaults (enabled, 60s)
+ SettingKeyChannelMonitorEnabled: "true",
+ SettingKeyChannelMonitorDefaultIntervalSeconds: "60",
+
+ // Available channels feature (default disabled; opt-in)
+ SettingKeyAvailableChannelsEnabled: "false",
+
+ // Affiliate (邀请返利) feature (default disabled; opt-in)
+ SettingKeyAffiliateEnabled: "false",
+
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
SettingKeyMaxClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
- SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingKeyAllowUngroupedKeyScheduling: "false",
+ SettingKeyEnableAnthropicCacheTTL1hInjection: "false",
+ SettingPaymentVisibleMethodAlipaySource: "",
+ SettingPaymentVisibleMethodWxpaySource: "",
+ SettingPaymentVisibleMethodAlipayEnabled: "false",
+ SettingPaymentVisibleMethodWxpayEnabled: "false",
+ openAIAdvancedSchedulerSettingKey: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -1032,12 +1973,36 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
+ if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
+ result.DefaultUserRPMLimit = rpm
+ }
+
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
+ if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
+ result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
+ } else {
+ result.AffiliateRebateRate = AffiliateRebateRateDefault
+ }
+ if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
+ if freezeHours > AffiliateRebateFreezeHoursMax {
+ freezeHours = AffiliateRebateFreezeHoursMax
+ }
+ result.AffiliateRebateFreezeHours = freezeHours
+ }
+ if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
+ if durationDays > AffiliateRebateDurationDaysMax {
+ durationDays = AffiliateRebateDurationDaysMax
+ }
+ result.AffiliateRebateDurationDays = durationDays
+ }
+ if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
+ result.AffiliateRebatePerInviteeCap = perInviteeCap
+ }
result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
@@ -1157,12 +2122,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
- result.OIDCConnectUsePKCE = oidcBase.UsePKCE
+ result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
- result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
+ result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
@@ -1208,6 +2173,31 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
+ // WeChat Connect 设置:
+ // - 优先读取 DB 系统设置
+ // - 缺失时回退到 config/env,保持升级兼容
+ weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings)
+ result.WeChatConnectEnabled = weChatEffective.Enabled
+ result.WeChatConnectAppID = weChatEffective.LegacyAppID
+ result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret
+ result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != ""
+ result.WeChatConnectOpenAppID = weChatEffective.OpenAppID
+ result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret
+ result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != ""
+ result.WeChatConnectMPAppID = weChatEffective.MPAppID
+ result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret
+ result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != ""
+ result.WeChatConnectMobileAppID = weChatEffective.MobileAppID
+ result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret
+ result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != ""
+ result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled
+ result.WeChatConnectMPEnabled = weChatEffective.MPEnabled
+ result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled
+ result.WeChatConnectMode = weChatEffective.Mode
+ result.WeChatConnectScopes = weChatEffective.Scopes
+ result.WeChatConnectRedirectURL = weChatEffective.RedirectURL
+ result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL
+
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
@@ -1240,6 +2230,18 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
}
+ // Channel monitor feature (default: enabled, 60s)
+ result.ChannelMonitorEnabled = !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled])
+ result.ChannelMonitorDefaultIntervalSeconds = parseChannelMonitorInterval(
+ settings[SettingKeyChannelMonitorDefaultIntervalSeconds],
+ )
+
+ // Available channels feature (default: disabled; strict true)
+ result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true"
+
+ // Affiliate (邀请返利) feature (default: disabled; strict true)
+ result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"
+
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]
@@ -1255,6 +2257,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
+ result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
@@ -1263,6 +2266,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
+ result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource])
+ result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource])
+ result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true"
+ result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
+ result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
@@ -1283,6 +2291,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
return result
}
+func clampAffiliateRebateRate(value float64) float64 {
+ if math.IsNaN(value) || math.IsInf(value, 0) {
+ return AffiliateRebateRateDefault
+ }
+ if value < AffiliateRebateRateMin {
+ return AffiliateRebateRateMin
+ }
+ if value > AffiliateRebateRateMax {
+ return AffiliateRebateRateMax
+ }
+ return value
+}
+
func isFalseSettingValue(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "false", "0", "off", "disabled":
@@ -1292,6 +2313,23 @@ func isFalseSettingValue(value string) bool {
}
}
+func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) {
+ _ = enabled
+ source = strings.TrimSpace(source)
+ if source == "" {
+ return "", nil
+ }
+
+ normalized := NormalizeVisibleMethodSource(method, source)
+ if normalized == "" {
+ return "", infraerrors.BadRequest(
+ "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE",
+ fmt.Sprintf("%s source must be one of the supported payment providers", method),
+ )
+ }
+ return normalized, nil
+}
+
func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
raw = strings.TrimSpace(raw)
if raw == "" {
@@ -1317,6 +2355,73 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
+func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
+ result := ProviderDefaultGrantSettings{
+ Balance: defaultAuthSourceBalance,
+ Concurrency: defaultAuthSourceConcurrency,
+ Subscriptions: []DefaultSubscriptionSetting{},
+ GrantOnSignup: false,
+ GrantOnFirstBind: false,
+ }
+
+ if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
+ result.Balance = v
+ }
+ if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
+ result.Concurrency = v
+ }
+ if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
+ result.Subscriptions = items
+ }
+ if raw, ok := settings[keys.grantOnSignup]; ok {
+ result.GrantOnSignup = raw == "true"
+ }
+ if raw, ok := settings[keys.grantOnFirstBind]; ok {
+ result.GrantOnFirstBind = raw == "true"
+ }
+
+ return result
+}
+
+func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
+ updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
+ updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
+
+ subscriptions := settings.Subscriptions
+ if subscriptions == nil {
+ subscriptions = []DefaultSubscriptionSetting{}
+ }
+ raw, err := json.Marshal(subscriptions)
+ if err != nil {
+ raw = []byte("[]")
+ }
+ updates[keys.subscriptions] = string(raw)
+ updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
+ updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
+}
+
+func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings {
+ result := ProviderDefaultGrantSettings{
+ Balance: globalDefaults.Balance,
+ Concurrency: globalDefaults.Concurrency,
+ Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...),
+ GrantOnSignup: providerDefaults.GrantOnSignup,
+ GrantOnFirstBind: providerDefaults.GrantOnFirstBind,
+ }
+
+ if providerDefaults.Balance != defaultAuthSourceBalance {
+ result.Balance = providerDefaults.Balance
+ }
+ if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency {
+ result.Concurrency = providerDefaults.Concurrency
+ }
+ if len(providerDefaults.Subscriptions) > 0 {
+ result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...)
+ }
+
+ return result
+}
+
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
@@ -1539,7 +2644,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
-
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
@@ -1587,9 +2691,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@@ -1597,6 +2698,35 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil
}
+// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。
+//
+// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。
+func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) {
+ keys := []string{
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectAppID,
+ SettingKeyWeChatConnectAppSecret,
+ SettingKeyWeChatConnectOpenAppID,
+ SettingKeyWeChatConnectOpenAppSecret,
+ SettingKeyWeChatConnectMPAppID,
+ SettingKeyWeChatConnectMPAppSecret,
+ SettingKeyWeChatConnectMobileAppID,
+ SettingKeyWeChatConnectMobileAppSecret,
+ SettingKeyWeChatConnectOpenEnabled,
+ SettingKeyWeChatConnectMPEnabled,
+ SettingKeyWeChatConnectMobileEnabled,
+ SettingKeyWeChatConnectMode,
+ SettingKeyWeChatConnectScopes,
+ SettingKeyWeChatConnectRedirectURL,
+ SettingKeyWeChatConnectFrontendRedirectURL,
+ }
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err)
+ }
+ return s.parseWeChatConnectOAuthConfig(settings)
+}
+
// GetOverloadCooldownSettings 获取529过载冷却配置
func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings)
@@ -1733,9 +2863,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
+ } else {
+ effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
+ } else {
+ effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
@@ -1864,9 +2998,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
- if !effective.UsePKCE {
- return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
- }
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@@ -2158,6 +3289,84 @@ func (s *SettingService) SetBetaPolicySettings(ctx context.Context, settings *Be
return s.settingRepo.Set(ctx, SettingKeyBetaPolicySettings, string(data))
}
+// GetOpenAIFastPolicySettings 获取 OpenAI fast 策略配置
+func (s *SettingService) GetOpenAIFastPolicySettings(ctx context.Context) (*OpenAIFastPolicySettings, error) {
+ value, err := s.settingRepo.GetValue(ctx, SettingKeyOpenAIFastPolicySettings)
+ if err != nil {
+ if errors.Is(err, ErrSettingNotFound) {
+ return DefaultOpenAIFastPolicySettings(), nil
+ }
+ return nil, fmt.Errorf("get openai fast policy settings: %w", err)
+ }
+ if value == "" {
+ return DefaultOpenAIFastPolicySettings(), nil
+ }
+
+ var settings OpenAIFastPolicySettings
+ if err := json.Unmarshal([]byte(value), &settings); err != nil {
+ // JSON 损坏时静默 fallback 到默认配置会让策略意外失效(管理员配
+ // 置的 block/filter 规则被忽略)。记录 Warn 让运维能在出现异常
+ // 行为时定位到 settings 表里的脏数据。
+ slog.Warn("failed to unmarshal openai fast policy settings, falling back to defaults",
+ "error", err,
+ "key", SettingKeyOpenAIFastPolicySettings)
+ return DefaultOpenAIFastPolicySettings(), nil
+ }
+
+ return &settings, nil
+}
+
+// SetOpenAIFastPolicySettings 设置 OpenAI fast 策略配置
+func (s *SettingService) SetOpenAIFastPolicySettings(ctx context.Context, settings *OpenAIFastPolicySettings) error {
+ if settings == nil {
+ return fmt.Errorf("settings cannot be nil")
+ }
+
+ validActions := map[string]bool{
+ BetaPolicyActionPass: true, BetaPolicyActionFilter: true, BetaPolicyActionBlock: true,
+ }
+ validScopes := map[string]bool{
+ BetaPolicyScopeAll: true, BetaPolicyScopeOAuth: true, BetaPolicyScopeAPIKey: true, BetaPolicyScopeBedrock: true,
+ }
+ validTiers := map[string]bool{
+ OpenAIFastTierAny: true, OpenAIFastTierPriority: true, OpenAIFastTierFlex: true,
+ }
+
+ for i, rule := range settings.Rules {
+ tier := strings.ToLower(strings.TrimSpace(rule.ServiceTier))
+ if tier == "" {
+ tier = OpenAIFastTierAny
+ }
+ if !validTiers[tier] {
+ return fmt.Errorf("rule[%d]: invalid service_tier %q", i, rule.ServiceTier)
+ }
+ settings.Rules[i].ServiceTier = tier
+ if !validActions[rule.Action] {
+ return fmt.Errorf("rule[%d]: invalid action %q", i, rule.Action)
+ }
+ if !validScopes[rule.Scope] {
+ return fmt.Errorf("rule[%d]: invalid scope %q", i, rule.Scope)
+ }
+ for j, pattern := range rule.ModelWhitelist {
+ trimmed := strings.TrimSpace(pattern)
+ if trimmed == "" {
+ return fmt.Errorf("rule[%d]: model_whitelist[%d] cannot be empty", i, j)
+ }
+ settings.Rules[i].ModelWhitelist[j] = trimmed
+ }
+ if rule.FallbackAction != "" && !validActions[rule.FallbackAction] {
+ return fmt.Errorf("rule[%d]: invalid fallback_action %q", i, rule.FallbackAction)
+ }
+ }
+
+ data, err := json.Marshal(settings)
+ if err != nil {
+ return fmt.Errorf("marshal openai fast policy settings: %w", err)
+ }
+
+ return s.settingRepo.Set(ctx, SettingKeyOpenAIFastPolicySettings, string(data))
+}
+
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go
new file mode 100644
index 00000000..1ff49740
--- /dev/null
+++ b/backend/internal/service/setting_service_auth_source_defaults_test.go
@@ -0,0 +1,138 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type authSourceDefaultsRepoStub struct {
+ values map[string]string
+ updates map[string]string
+}
+
+func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.updates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.updates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
+ repo := &authSourceDefaultsRepoStub{
+ values: map[string]string{
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
+ SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ got, err := svc.GetAuthSourceDefaultSettings(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, 12.5, got.Email.Balance)
+ require.Equal(t, 7, got.Email.Concurrency)
+ require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
+ require.False(t, got.Email.GrantOnSignup)
+ require.False(t, got.Email.GrantOnFirstBind)
+ require.Equal(t, 0.0, got.LinuxDo.Balance)
+ require.Equal(t, 5, got.LinuxDo.Concurrency)
+ require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
+ require.False(t, got.LinuxDo.GrantOnSignup)
+ require.True(t, got.LinuxDo.GrantOnFirstBind)
+ require.Equal(t, 5, got.OIDC.Concurrency)
+ require.Equal(t, 5, got.WeChat.Concurrency)
+ require.False(t, got.OIDC.GrantOnSignup)
+ require.False(t, got.WeChat.GrantOnSignup)
+ require.True(t, got.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
+ repo := &authSourceDefaultsRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
+ Email: ProviderDefaultGrantSettings{
+ Balance: 1.25,
+ Concurrency: 3,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ LinuxDo: ProviderDefaultGrantSettings{
+ Balance: 2,
+ Concurrency: 4,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ OIDC: ProviderDefaultGrantSettings{
+ Balance: 3,
+ Concurrency: 5,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
+ GrantOnSignup: true,
+ GrantOnFirstBind: true,
+ },
+ WeChat: ProviderDefaultGrantSettings{
+ Balance: 4,
+ Concurrency: 6,
+ Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
+ require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
+ require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
+
+ var got []DefaultSubscriptionSetting
+ require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
+ require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
+}
diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go
index 3809b332..61324204 100644
--- a/backend/internal/service/setting_service_oidc_config_test.go
+++ b/backend/internal/service/setting_service_oidc_config_test.go
@@ -101,3 +101,151 @@ func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testi
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL)
}
+
+func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t *testing.T) {
+ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{})
+
+ got := svc.parseSettings(map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ SettingKeyOIDCConnectUsePKCE: "false",
+ SettingKeyOIDCConnectValidateIDToken: "false",
+ })
+
+ require.False(t, got.OIDCConnectUsePKCE)
+ require.False(t, got.OIDCConnectValidateIDToken)
+}
+
+func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
+ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ UsePKCE: true,
+ UsePKCEExplicit: true,
+ ValidateIDToken: true,
+ ValidateIDTokenExplicit: true,
+ },
+ })
+
+ got := svc.parseSettings(map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ })
+
+ require.True(t, got.OIDCConnectUsePKCE)
+ require.True(t, got.OIDCConnectValidateIDToken)
+}
+
+func TestSettingService_ParseSettings_DefaultsOIDCCompatibilityFlagsToSafeDefaultsWhenSettingsMissing(t *testing.T) {
+ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ UsePKCE: true,
+ ValidateIDToken: true,
+ },
+ })
+
+ got := svc.parseSettings(map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ })
+
+ require.True(t, got.OIDCConnectUsePKCE)
+ require.True(t, got.OIDCConnectValidateIDToken)
+}
+
+func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
+ cfg := &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ },
+ }
+
+ repo := &settingOIDCRepoStub{values: map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ SettingKeyOIDCConnectUsePKCE: "false",
+ SettingKeyOIDCConnectValidateIDToken: "false",
+ }}
+ svc := NewSettingService(repo, cfg)
+
+ got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.False(t, got.UsePKCE)
+ require.False(t, got.ValidateIDToken)
+}
+
+func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
+ cfg := &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ UsePKCEExplicit: true,
+ ValidateIDToken: true,
+ ValidateIDTokenExplicit: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ }
+
+ repo := &settingOIDCRepoStub{values: map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ }}
+ svc := NewSettingService(repo, cfg)
+
+ got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.UsePKCE)
+ require.True(t, got.ValidateIDToken)
+}
+
+func TestGetOIDCConnectOAuthConfig_DefaultsCompatibilityFlagsToSafeValuesWhenSettingsMissing(t *testing.T) {
+ cfg := &config.Config{
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ }
+
+ repo := &settingOIDCRepoStub{values: map[string]string{
+ SettingKeyOIDCConnectEnabled: "true",
+ }}
+ svc := NewSettingService(repo, cfg)
+
+ got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.UsePKCE)
+ require.True(t, got.ValidateIDToken)
+}
diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go
index 5cf1e860..1ecd4e6f 100644
--- a/backend/internal/service/setting_service_public_test.go
+++ b/backend/internal/service/setting_service_public_test.go
@@ -77,3 +77,77 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T)
require.Equal(t, 50, settings.TableDefaultPageSize)
require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions)
}
+
+func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ repo := &settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
+ svc := NewSettingService(&settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-mp-app",
+ SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectOpenEnabled: "true",
+ SettingKeyWeChatConnectMPEnabled: "true",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.WeChatOAuthEnabled)
+ require.True(t, settings.WeChatOAuthOpenEnabled)
+ require.True(t, settings.WeChatOAuthMPEnabled)
+}
+
+func TestSettingService_GetPublicSettings_DoesNotExposeMobileOnlyWeChatAsWebOAuthAvailable(t *testing.T) {
+ svc := NewSettingService(&settingPublicRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectMobileEnabled: "true",
+ SettingKeyWeChatConnectMode: "mobile",
+ SettingKeyWeChatConnectMobileAppID: "wx-mobile-app",
+ SettingKeyWeChatConnectMobileAppSecret: "wx-mobile-secret",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{})
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.False(t, settings.WeChatOAuthEnabled)
+ require.False(t, settings.WeChatOAuthOpenEnabled)
+ require.False(t, settings.WeChatOAuthMPEnabled)
+ require.True(t, settings.WeChatOAuthMobileEnabled)
+}
+
+func TestSettingService_GetPublicSettings_FallsBackToConfigForWeChatOAuthCapabilities(t *testing.T) {
+ svc := NewSettingService(&settingPublicRepoStub{values: map[string]string{}}, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ settings, err := svc.GetPublicSettings(context.Background())
+ require.NoError(t, err)
+ require.True(t, settings.WeChatOAuthEnabled)
+ require.True(t, settings.WeChatOAuthOpenEnabled)
+ require.False(t, settings.WeChatOAuthMPEnabled)
+ require.False(t, settings.WeChatOAuthMobileEnabled)
+}
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
index e62218b4..9dc0ca59 100644
--- a/backend/internal/service/setting_service_update_test.go
+++ b/backend/internal/service/setting_service_update_test.go
@@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize])
require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions])
}
+
+func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "alipay",
+ PaymentVisibleMethodWxpaySource: "easypay",
+ PaymentVisibleMethodAlipayEnabled: true,
+ PaymentVisibleMethodWxpayEnabled: false,
+ OpenAIAdvancedSchedulerEnabled: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey])
+}
+
+func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ PaymentVisibleMethodAlipaySource: "not-a-provider",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err))
+ require.Nil(t, repo.updates)
+}
diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go
new file mode 100644
index 00000000..a2de614b
--- /dev/null
+++ b/backend/internal/service/setting_service_wechat_config_test.go
@@ -0,0 +1,162 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type settingWeChatRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if value, ok := s.values[key]; ok {
+ return value, nil
+ }
+ return "", ErrSettingNotFound
+}
+
+func (s *settingWeChatRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingWeChatRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) {
+ repo := &settingWeChatRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectEnabled: "true",
+ SettingKeyWeChatConnectAppID: "wx-db-app",
+ SettingKeyWeChatConnectAppSecret: "wx-db-secret",
+ SettingKeyWeChatConnectMode: "mp",
+ SettingKeyWeChatConnectScopes: "snsapi_base",
+ SettingKeyWeChatConnectOpenEnabled: "true",
+ SettingKeyWeChatConnectMPEnabled: "true",
+ SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.Equal(t, "wx-db-app", got.AppIDForMode("mp"))
+ require.Equal(t, "wx-db-secret", got.AppSecretForMode("mp"))
+ require.True(t, got.OpenEnabled)
+ require.True(t, got.MPEnabled)
+ require.Equal(t, "mp", got.Mode)
+ require.Equal(t, "snsapi_base", got.Scopes)
+ require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL)
+ require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL)
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_FallsBackToConfigWhenDatabaseEmpty(t *testing.T) {
+ repo := &settingWeChatRepoStub{values: map[string]string{}}
+ svc := NewSettingService(repo, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ MPEnabled: true,
+ Mode: "open",
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ MPAppID: "wx-mp-config",
+ MPAppSecret: "wx-mp-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.True(t, got.OpenEnabled)
+ require.True(t, got.MPEnabled)
+ require.Equal(t, "wx-open-config", got.AppIDForMode("open"))
+ require.Equal(t, "wx-open-secret", got.AppSecretForMode("open"))
+ require.Equal(t, "wx-mp-config", got.AppIDForMode("mp"))
+ require.Equal(t, "wx-mp-secret", got.AppSecretForMode("mp"))
+ require.Equal(t, "/auth/wechat/config-callback", got.FrontendRedirectURL)
+ require.Empty(t, got.RedirectURL)
+}
+
+func TestSettingService_GetWeChatConnectOAuthConfig_IgnoresSyntheticDisabledCapabilitiesFromMigration118(t *testing.T) {
+ repo := &settingWeChatRepoStub{
+ values: map[string]string{
+ SettingKeyWeChatConnectOpenEnabled: "false",
+ SettingKeyWeChatConnectMPEnabled: "false",
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ MPEnabled: true,
+ Mode: "open",
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ MPAppID: "wx-mp-config",
+ MPAppSecret: "wx-mp-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ got, err := svc.GetWeChatConnectOAuthConfig(context.Background())
+ require.NoError(t, err)
+ require.True(t, got.Enabled)
+ require.True(t, got.OpenEnabled)
+ require.True(t, got.MPEnabled)
+ require.Equal(t, "wx-open-config", got.AppIDForMode("open"))
+ require.Equal(t, "wx-mp-config", got.AppIDForMode("mp"))
+}
+
+func TestSettingService_ParseSettings_FallsBackToConfigForWeChatAdminView(t *testing.T) {
+ svc := NewSettingService(&settingWeChatRepoStub{values: map[string]string{}}, &config.Config{
+ WeChat: config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ Mode: "open",
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ FrontendRedirectURL: "/auth/wechat/config-callback",
+ },
+ })
+
+ got := svc.parseSettings(map[string]string{})
+ require.True(t, got.WeChatConnectEnabled)
+ require.True(t, got.WeChatConnectOpenEnabled)
+ require.Equal(t, "wx-open-config", got.WeChatConnectOpenAppID)
+ require.True(t, got.WeChatConnectOpenAppSecretConfigured)
+ require.Equal(t, "/auth/wechat/config-callback", got.WeChatConnectFrontendRedirectURL)
+ require.Equal(t, "open", got.WeChatConnectMode)
+ require.Equal(t, "snsapi_login", got.WeChatConnectScopes)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index ab2eb274..41c01cca 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -1,5 +1,16 @@
package service
+import "strings"
+
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
@@ -31,6 +42,28 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool
+ WeChatConnectAppID string
+ WeChatConnectAppSecret string
+ WeChatConnectAppSecretConfigured bool
+ WeChatConnectOpenAppID string
+ WeChatConnectOpenAppSecret string
+ WeChatConnectOpenAppSecretConfigured bool
+ WeChatConnectMPAppID string
+ WeChatConnectMPAppSecret string
+ WeChatConnectMPAppSecretConfigured bool
+ WeChatConnectMobileAppID string
+ WeChatConnectMobileAppSecret string
+ WeChatConnectMobileAppSecretConfigured bool
+ WeChatConnectOpenEnabled bool
+ WeChatConnectMPEnabled bool
+ WeChatConnectMobileEnabled bool
+ WeChatConnectMode string
+ WeChatConnectScopes string
+ WeChatConnectRedirectURL string
+ WeChatConnectFrontendRedirectURL string
+
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool
OIDCConnectProviderName string
@@ -71,9 +104,15 @@ type SystemSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
- DefaultConcurrency int
- DefaultBalance float64
- DefaultSubscriptions []DefaultSubscriptionSetting
+ DefaultConcurrency int
+ DefaultBalance float64
+ AffiliateEnabled bool
+ AffiliateRebateRate float64
+ AffiliateRebateFreezeHours int
+ AffiliateRebateDurationDays int
+ AffiliateRebatePerInviteeCap float64
+ DefaultUserRPMLimit int
+ DefaultSubscriptions []DefaultSubscriptionSetting
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -92,6 +131,13 @@ type SystemSettings struct {
OpsQueryModeDefault string
OpsMetricsIntervalSeconds int
+ // Channel Monitor feature
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+
+ // Available Channels feature (user-facing aggregate view)
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
// Claude Code version check
MinClaudeCodeVersion string
MaxClaudeCodeVersion string
@@ -103,13 +149,23 @@ type SystemSettings struct {
BackendModeEnabled bool
// Gateway forwarding behavior
- EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
- EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
- EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
+ EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
+ EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
+ EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
+ EnableAnthropicCacheTTL1hInjection bool // 是否对 Anthropic OAuth/SetupToken 请求体注入 1h cache_control ttl(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string
+ PaymentVisibleMethodWxpaySource string
+ PaymentVisibleMethodAlipayEnabled bool
+ PaymentVisibleMethodWxpayEnabled bool
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool
+
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
@@ -128,6 +184,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
+ ForceEmailOnThirdPartySignup bool
RegistrationEmailSuffixWhitelist []string
PromoCodeEnabled bool
PasswordResetEnabled bool
@@ -151,17 +208,91 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
- LinuxDoOAuthEnabled bool
- BackendModeEnabled bool
- PaymentEnabled bool
- OIDCOAuthEnabled bool
- OIDCOAuthProviderName string
- Version string
+ LinuxDoOAuthEnabled bool
+ WeChatOAuthEnabled bool
+ WeChatOAuthOpenEnabled bool
+ WeChatOAuthMPEnabled bool
+ WeChatOAuthMobileEnabled bool
+ BackendModeEnabled bool
+ PaymentEnabled bool
+ OIDCOAuthEnabled bool
+ OIDCOAuthProviderName string
+ Version string
BalanceLowNotifyEnabled bool
AccountQuotaNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
+
+ // Channel Monitor feature
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+
+ // Available Channels feature (user-facing aggregate view)
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature toggle
+ AffiliateEnabled bool `json:"affiliate_enabled"`
+}
+
+type WeChatConnectOAuthConfig struct {
+ Enabled bool
+ LegacyAppID string
+ LegacyAppSecret string
+ OpenAppID string
+ OpenAppSecret string
+ MPAppID string
+ MPAppSecret string
+ MobileAppID string
+ MobileAppSecret string
+ OpenEnabled bool
+ MPEnabled bool
+ MobileEnabled bool
+ Mode string
+ Scopes string
+ RedirectURL string
+ FrontendRedirectURL string
+}
+
+func (cfg WeChatConnectOAuthConfig) SupportsMode(mode string) bool {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return cfg.MPEnabled
+ case "mobile":
+ return cfg.MobileEnabled
+ default:
+ return cfg.OpenEnabled
+ }
+}
+
+func (cfg WeChatConnectOAuthConfig) ScopeForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return normalizeWeChatConnectScopeSetting(cfg.Scopes, "mp")
+ case "mobile":
+ return ""
+ }
+ return defaultWeChatConnectScopeForMode("open")
+}
+
+func (cfg WeChatConnectOAuthConfig) AppIDForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return strings.TrimSpace(firstNonEmpty(cfg.MPAppID, cfg.LegacyAppID))
+ case "mobile":
+ return strings.TrimSpace(firstNonEmpty(cfg.MobileAppID, cfg.LegacyAppID))
+ }
+ return strings.TrimSpace(firstNonEmpty(cfg.OpenAppID, cfg.LegacyAppID))
+}
+
+func (cfg WeChatConnectOAuthConfig) AppSecretForMode(mode string) string {
+ switch normalizeWeChatConnectModeSetting(mode) {
+ case "mp":
+ return strings.TrimSpace(firstNonEmpty(cfg.MPAppSecret, cfg.LegacyAppSecret))
+ case "mobile":
+ return strings.TrimSpace(firstNonEmpty(cfg.MobileAppSecret, cfg.LegacyAppSecret))
+ }
+ return strings.TrimSpace(firstNonEmpty(cfg.OpenAppSecret, cfg.LegacyAppSecret))
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
@@ -275,3 +406,57 @@ func DefaultBetaPolicySettings() *BetaPolicySettings {
},
}
}
+
+// OpenAI Fast Policy 策略常量
+// OpenAI 的 "fast 模式" 通过请求体中的 service_tier 字段识别:
+// - "priority"(客户端可传 "fast",归一化为 "priority"):fast 模式
+// - "flex":低优先级模式
+// - 省略:normal 默认
+//
+// 本策略复用 BetaPolicyAction*/BetaPolicyScope* 常量语义,只是匹配键从
+// anthropic-beta header 换成 body 的 service_tier 字段。
+const (
+ OpenAIFastTierAny = "all" // 匹配任意已识别的 service_tier
+ OpenAIFastTierPriority = "priority" // 仅匹配 fast(priority)
+ OpenAIFastTierFlex = "flex" // 仅匹配 flex
+)
+
+// OpenAIFastPolicyRule 单条 OpenAI fast/flex 策略规则
+type OpenAIFastPolicyRule struct {
+ ServiceTier string `json:"service_tier"` // "priority" | "flex" | "auto" | "default" | "scale" | "all"
+ Action string `json:"action"` // "pass" | "filter" | "block"
+ Scope string `json:"scope"` // "all" | "oauth" | "apikey" | "bedrock"
+ ErrorMessage string `json:"error_message,omitempty"` // 自定义错误消息 (action=block 时生效)
+ ModelWhitelist []string `json:"model_whitelist,omitempty"` // 模型匹配模式列表(为空=对所有模型生效)
+ FallbackAction string `json:"fallback_action,omitempty"` // 未匹配白名单的模型的处理方式
+ FallbackErrorMessage string `json:"fallback_error_message,omitempty"` // 未匹配白名单时的自定义错误消息 (fallback_action=block 时生效)
+}
+
+// OpenAIFastPolicySettings OpenAI fast 策略配置
+type OpenAIFastPolicySettings struct {
+ Rules []OpenAIFastPolicyRule `json:"rules"`
+}
+
+// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。
+// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段,
+// 让上游按 normal 优先级处理。
+//
+// 为什么 ModelWhitelist 为空(=对所有模型生效):
+// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使
+// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁
+// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。
+// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定
+// 模型,可在 admin UI 中显式配置 model_whitelist。
+func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
+ return &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{
+ {
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{},
+ FallbackAction: BetaPolicyActionPass,
+ },
+ },
+ }
+}
diff --git a/backend/internal/service/sql_errors.go b/backend/internal/service/sql_errors.go
new file mode 100644
index 00000000..7c0155a4
--- /dev/null
+++ b/backend/internal/service/sql_errors.go
@@ -0,0 +1,14 @@
+package service
+
+import (
+ "database/sql"
+ "errors"
+ "strings"
+)
+
+func isSQLNoRowsError(err error) bool {
+ if err == nil {
+ return false
+ }
+ return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
+}
diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go
index e7ef8982..11ace7bd 100644
--- a/backend/internal/service/sticky_session_test.go
+++ b/backend/internal/service/sticky_session_test.go
@@ -15,20 +15,8 @@ import (
"github.com/stretchr/testify/require"
)
-// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
-// 验证在以下情况下是否正确判断需要清理粘性会话:
-// - nil 账号:不清理(返回 false)
-// - 状态为错误或禁用:清理
-// - 不可调度:清理
-// - 临时不可调度且未过期:清理
-// - 临时不可调度已过期:不清理
-// - 正常可调度状态:不清理
-// - 模型限流(任意时长):清理
-//
-// TestShouldClearStickySession tests the sticky session clearing logic.
-// Verifies correct behavior for various account states including:
-// nil account, error/disabled status, unschedulable, temporary unschedulable,
-// and model rate limiting scenarios.
+// TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation
+// plus model-level rate limiting.
func TestShouldClearStickySession(t *testing.T) {
now := time.Now()
future := now.Add(1 * time.Hour)
@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) {
requestedModel: "claude-opus-4", // 请求不同模型
want: false, // 不同模型不受影响
},
+ {
+ name: "apikey quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "oauth quota exceeded not cleared",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ requestedModel: "",
+ want: false,
+ },
+ {
+ name: "overloaded account",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ OverloadUntil: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
+ {
+ name: "account-level rate limited",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &future,
+ },
+ requestedModel: "",
+ want: true,
+ },
}
for _, tt := range tests {
diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go
index 5192fe3d..052739ed 100644
--- a/backend/internal/service/totp_service.go
+++ b/backend/internal/service/totp_service.go
@@ -58,9 +58,15 @@ type TotpSetupSession struct {
// TotpLoginSession represents a pending 2FA login session
type TotpLoginSession struct {
- UserID int64
- Email string
- TokenExpiry time.Time
+ UserID int64
+ Email string
+ TokenExpiry time.Time
+ PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"`
+}
+
+type PendingOAuthBindLoginSession struct {
+ PendingSessionToken string `json:"pending_session_token,omitempty"`
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
}
// TotpStatus represents the TOTP status for a user
@@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string)
// CreateLoginSession creates a temporary login session for 2FA
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
+ return s.createLoginSession(ctx, userID, email, nil)
+}
+
+// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will
+// finalize a pending OAuth bind after the TOTP code is verified.
+func (s *TotpService) CreatePendingOAuthBindLoginSession(
+ ctx context.Context,
+ userID int64,
+ email string,
+ pendingSessionToken string,
+ browserSessionKey string,
+) (string, error) {
+ return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{
+ PendingSessionToken: pendingSessionToken,
+ BrowserSessionKey: browserSessionKey,
+ })
+}
+
+func (s *TotpService) createLoginSession(
+ ctx context.Context,
+ userID int64,
+ email string,
+ pendingOAuthBind *PendingOAuthBindLoginSession,
+) (string, error) {
// Generate a random temp token
tempToken, err := generateRandomToken(32)
if err != nil {
@@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai
}
session := &TotpLoginSession{
- UserID: userID,
- Email: email,
- TokenExpiry: time.Now().Add(totpLoginTTL),
+ UserID: userID,
+ Email: email,
+ TokenExpiry: time.Now().Add(totpLoginTTL),
+ PendingOAuthBind: pendingOAuthBind,
}
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go
index a0444d52..ddf0e818 100644
--- a/backend/internal/service/upstream_response_limit.go
+++ b/backend/internal/service/upstream_response_limit.go
@@ -12,7 +12,9 @@ import (
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
-const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
+// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes,
+// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。
+const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index 59f8aa6b..f9833611 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -7,19 +7,31 @@ import (
)
type User struct {
- ID int64
- Email string
- Username string
- Notes string
- PasswordHash string
- Role string
- Balance float64
- Concurrency int
- Status string
- AllowedGroups []int64
- TokenVersion int64 // Incremented on password change to invalidate existing tokens
- CreatedAt time.Time
- UpdatedAt time.Time
+ ID int64
+ Email string
+ Username string
+ Notes string
+ AvatarURL string
+ AvatarSource string
+ AvatarMIME string
+ AvatarByteSize int
+ AvatarSHA256 string
+ PasswordHash string
+ Role string
+ Balance float64
+ Concurrency int
+ Status string
+ AllowedGroups []int64
+ TokenVersion int64 // Incremented on password change to invalidate existing tokens
+ // TokenVersionResolved indicates TokenVersion already contains the fingerprint-derived
+ // value expected in JWT claims and refresh-token state.
+ TokenVersionResolved bool
+ SignupSource string
+ LastLoginAt *time.Time
+ LastActiveAt *time.Time
+ LastUsedAt *time.Time
+ CreatedAt time.Time
+ UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
@@ -37,6 +49,15 @@ type User struct {
BalanceNotifyExtraEmails []NotifyEmailEntry
TotalRecharged float64
+ // RPMLimit 用户级每分钟请求数上限(0 = 不限制)。仅在所用分组未设置 rpm_limit
+ // 且该 (用户, 分组) 无 rpm_override 时作为全局兜底生效,计数键 rpm:u:{userID}:{min}。
+ RPMLimit int
+
+ // UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。
+ // nil = 该 API Key 对应的 (user, group) 无 override;非 nil 时 checkRPM 直接使用,
+ // 避免每请求查 DB。字段不持久化到数据库。
+ UserGroupRPMOverride *int
+
APIKeys []APIKey
Subscriptions []UserSubscription
}
diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go
index 3d221a25..f069eb7e 100644
--- a/backend/internal/service/user_group_rate.go
+++ b/backend/internal/service/user_group_rate.go
@@ -2,14 +2,16 @@ package service
import "context"
-// UserGroupRateEntry 分组下用户专属倍率条目
+// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。
+// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。
type UserGroupRateEntry struct {
- UserID int64 `json:"user_id"`
- UserName string `json:"user_name"`
- UserEmail string `json:"user_email"`
- UserNotes string `json:"user_notes"`
- UserStatus string `json:"user_status"`
- RateMultiplier float64 `json:"rate_multiplier"`
+ UserID int64 `json:"user_id"`
+ UserName string `json:"user_name"`
+ UserEmail string `json:"user_email"`
+ UserNotes string `json:"user_notes"`
+ UserStatus string `json:"user_status"`
+ RateMultiplier *float64 `json:"rate_multiplier,omitempty"`
+ RPMOverride *int `json:"rpm_override,omitempty"`
}
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
@@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct {
RateMultiplier float64 `json:"rate_multiplier"`
}
-// UserGroupRateRepository 用户专属分组倍率仓储接口
-// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
+// GroupRPMOverrideInput 批量设置分组 RPM override 的输入条目。
+// RPMOverride 为 *int 以支持清除(nil)语义。
+type GroupRPMOverrideInput struct {
+ UserID int64 `json:"user_id"`
+ RPMOverride *int `json:"rpm_override"`
+}
+
+// UserGroupRateRepository 用户专属分组倍率/RPM 仓储接口。
+// 允许管理员为特定用户设置分组的专属计费倍率与 RPM 上限,覆盖分组默认值。
type UserGroupRateRepository interface {
- // GetByUserID 获取用户的所有专属分组倍率
- // 返回 map[groupID]rateMultiplier
+ // GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
- // GetByUserAndGroup 获取用户在特定分组的专属倍率
- // 如果未设置专属倍率,返回 nil
+ // GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
- // GetByGroupID 获取指定分组下所有用户的专属倍率
+ // GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
+ GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error)
+
+ // GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
- // SyncUserGroupRates 同步用户的分组专属倍率
- // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
+ // SyncUserGroupRates 同步用户的分组专属倍率;nil 表示清空该分组的 rate_multiplier
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
- // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据)
+ // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组 rate 部分)
SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
- // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
+ // SyncGroupRPMOverrides 批量同步分组的用户专属 RPM(替换整组 rpm_override 部分)。
+ // 条目中 RPMOverride 为 nil 时清空对应行的 rpm_override;非 nil 时 upsert。
+ SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
+
+ // ClearGroupRPMOverrides 清空指定分组的所有 rpm_override(整组 rpm 部分归 NULL)
+ ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
+
+ // DeleteByGroupID 删除指定分组的所有用户专属条目(分组删除时调用)
DeleteByGroupID(ctx context.Context, groupID int64) error
- // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
+ // DeleteByUserID 删除指定用户的所有专属条目(用户删除时调用)
DeleteByUserID(ctx context.Context, userID int64) error
}
diff --git a/backend/internal/service/user_rpm_cache.go b/backend/internal/service/user_rpm_cache.go
new file mode 100644
index 00000000..b8857311
--- /dev/null
+++ b/backend/internal/service/user_rpm_cache.go
@@ -0,0 +1,25 @@
+package service
+
+import "context"
+
+// UserRPMCache 用户/分组级 RPM 计数器接口。
+//
+// 与账号级 RPMCache 的区别:
+// - RPMCache —— 按外部 AI provider 账号聚合(key: rpm:{accountID}:{min})。
+// - UserRPMCache —— 按用户或 (用户, 分组) 聚合,杜绝"同一用户创建多个 API Key 绕过 RPM"的路径。
+// key 形如 rpm:ug:{userID}:{groupID}:{min} 或 rpm:u:{userID}:{min}。
+type UserRPMCache interface {
+ // IncrementUserGroupRPM 原子递增 (user, group) 级分钟计数并返回最新值。
+ // 用于分组 rpm_limit 与 user-group rpm_override 两种命中分支。
+ IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)
+
+ // IncrementUserRPM 原子递增用户级分钟计数并返回最新值。
+ // 用于用户全局 rpm_limit 兜底分支(分组未设且无 override 时)。
+ IncrementUserRPM(ctx context.Context, userID int64) (count int, err error)
+
+ // GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读,不递增)。
+ GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error)
+
+ // GetUserRPM 获取用户当前分钟已用 RPM(只读,不递增)。
+ GetUserRPM(ctx context.Context, userID int64) (count int, err error)
+}
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 3490e804..a7279e6a 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -1,30 +1,66 @@
package service
import (
+ "bytes"
"context"
+ "crypto/sha256"
"crypto/subtle"
+ "encoding/base64"
+ "encoding/hex"
"fmt"
- "log/slog"
- "strings"
- "time"
-
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "image"
+ "image/color"
+ stddraw "image/draw"
+ _ "image/gif"
+ "image/jpeg"
+ _ "image/png"
+ "log/slog"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ xdraw "golang.org/x/image/draw"
+ "golang.org/x/sync/singleflight"
)
var (
- ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
- ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
- ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
- ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
+ ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
+ ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
+ ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
+ ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
+ ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
+ ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
+ ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
+ ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
+ ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
+ ErrIdentityUnbindLastMethod = infraerrors.Conflict(
+ "IDENTITY_UNBIND_LAST_METHOD",
+ "bind another sign-in method before unbinding this provider",
+ )
)
const (
- maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxNotifyEmails = 3 // Maximum number of notification emails per user
+ maxInlineAvatarBytes = 100 * 1024
+ targetAvatarBytes = 20 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
+
+ defaultUserIdentityRedirect = "/settings/profile"
+ userLastActiveMinTouch = 10 * time.Minute
+ userLastActiveFailBackoff = 30 * time.Second
+)
+
+var (
+ avatarScaleSteps = []float64{1, 0.92, 0.84, 0.76, 0.68, 0.6, 0.52, 0.44, 0.36}
+ avatarQualitySteps = []int{88, 80, 72, 64, 56, 48, 40, 32}
)
// UserListFilters contains all filter options for listing users
@@ -47,9 +83,15 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
+ GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
+ UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
+ GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error)
+ GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error)
+ UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -60,6 +102,8 @@ type UserRepository interface {
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
+ ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
+ UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error
// TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
@@ -67,15 +111,90 @@ type UserRepository interface {
DisableTotp(ctx context.Context, userID int64) error
}
+type UserAuthIdentityRecord struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+ VerifiedAt *time.Time
+ Issuer *string
+ Metadata map[string]any
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+type UserIdentitySummary struct {
+ Provider string `json:"provider"`
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name,omitempty"`
+ AvatarURL string `json:"-"`
+ SubjectHint string `json:"subject_hint,omitempty"`
+ ProviderKey string `json:"provider_key,omitempty"`
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ BindStartPath string `json:"bind_start_path,omitempty"`
+ CanBind bool `json:"can_bind"`
+ CanUnbind bool `json:"can_unbind"`
+ NoteKey string `json:"note_key,omitempty"`
+ Note string `json:"note,omitempty"`
+}
+
+type UserIdentitySummarySet struct {
+ Email UserIdentitySummary `json:"email"`
+ LinuxDo UserIdentitySummary `json:"linuxdo"`
+ OIDC UserIdentitySummary `json:"oidc"`
+ WeChat UserIdentitySummary `json:"wechat"`
+}
+
+type StartUserIdentityBindingRequest struct {
+ Provider string
+ RedirectTo string
+}
+
+type StartUserIdentityBindingResult struct {
+ Provider string `json:"provider"`
+ AuthorizeURL string `json:"authorize_url"`
+ Method string `json:"method"`
+ UseBrowserRedirect bool `json:"use_browser_redirect"`
+}
+
+const (
+ userIdentityNoteEmailManagedFromProfile = "profile.authBindings.notes.emailManagedFromProfile"
+ userIdentityNoteCanUnbind = "profile.authBindings.notes.canUnbind"
+ userIdentityNoteBindAnotherBeforeUnbind = "profile.authBindings.notes.bindAnotherBeforeUnbind"
+)
+
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
+type UserAvatar struct {
+ StorageProvider string
+ StorageKey string
+ URL string
+ ContentType string
+ ByteSize int
+ SHA256 string
+}
+
+type UpsertUserAvatarInput struct {
+ StorageProvider string
+ StorageKey string
+ URL string
+ ContentType string
+ ByteSize int
+ SHA256 string
+}
+
+type userProfileIdentityTxRunner interface {
+ WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error
+}
+
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -88,6 +207,8 @@ type UserService struct {
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
+ lastActiveTouchL1 sync.Map
+ lastActiveTouchSF singleflight.Group
}
// NewUserService 创建用户服务实例
@@ -115,14 +236,176 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ normalizeLoadedUserTokenVersion(user)
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
+func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) {
+ if user == nil {
+ var err error
+ user, err = s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err)
+ }
+ }
+
+ records, err := s.listUserAuthIdentities(ctx, userID)
+ if err != nil {
+ return UserIdentitySummarySet{}, err
+ }
+
+ summaries := UserIdentitySummarySet{
+ Email: s.buildEmailIdentitySummary(user, records),
+ LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
+ OIDC: s.buildProviderIdentitySummary("oidc", user, records),
+ WeChat: s.buildProviderIdentitySummary("wechat", user, records),
+ }
+
+ s.applyExplicitProviderAvailability(ctx, &summaries)
+ return summaries, nil
+}
+
+func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, summaries *UserIdentitySummarySet) {
+ if s == nil || summaries == nil || s.settingRepo == nil {
+ return
+ }
+
+ settings, err := s.settingRepo.GetMultiple(ctx, []string{
+ SettingKeyLinuxDoConnectEnabled,
+ SettingKeyOIDCConnectEnabled,
+ SettingKeyWeChatConnectEnabled,
+ SettingKeyWeChatConnectOpenEnabled,
+ SettingKeyWeChatConnectMPEnabled,
+ SettingKeyWeChatConnectMobileEnabled,
+ SettingKeyWeChatConnectMode,
+ })
+ if err != nil {
+ return
+ }
+
+ if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
+ disableIdentityBindAction(&summaries.LinuxDo)
+ }
+ if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
+ disableIdentityBindAction(&summaries.OIDC)
+ }
+ if raw, ok := settings[SettingKeyWeChatConnectEnabled]; ok && strings.TrimSpace(raw) != "" {
+ if raw != "true" {
+ disableIdentityBindAction(&summaries.WeChat)
+ return
+ }
+ openEnabled, mpEnabled, _ := parseWeChatConnectCapabilitySettings(settings, true, settings[SettingKeyWeChatConnectMode])
+ if !openEnabled && !mpEnabled {
+ disableIdentityBindAction(&summaries.WeChat)
+ }
+ }
+}
+
+func disableIdentityBindAction(summary *UserIdentitySummary) {
+ if summary == nil || summary.Bound {
+ return
+ }
+ summary.CanBind = false
+ summary.BindStartPath = ""
+}
+
+func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) {
+ provider := normalizeUserIdentityProvider(req.Provider)
+ if provider == "" {
+ return nil, ErrIdentityProviderInvalid
+ }
+
+ authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo)
+ if err != nil {
+ return nil, err
+ }
+
+ return &StartUserIdentityBindingResult{
+ Provider: provider,
+ AuthorizeURL: authorizeURL,
+ Method: "GET",
+ UseBrowserRedirect: true,
+ }, nil
+}
+
+func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) {
+ user, _, err := s.UnbindUserAuthProviderWithResult(ctx, userID, provider)
+ return user, err
+}
+
+func (s *UserService) UnbindUserAuthProviderWithResult(ctx context.Context, userID int64, provider string) (*User, bool, error) {
+ provider = normalizeUserIdentityProvider(provider)
+ if provider == "" || provider == "email" {
+ return nil, false, ErrIdentityProviderInvalid
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, false, fmt.Errorf("get user: %w", err)
+ }
+
+ records, err := s.listUserAuthIdentities(ctx, userID)
+ if err != nil {
+ return nil, false, err
+ }
+ if len(filterUserAuthIdentities(records, provider)) == 0 {
+ return user, false, nil
+ }
+ if !s.canUnbindProvider(provider, user, records) {
+ return nil, false, ErrIdentityUnbindLastMethod
+ }
+
+ if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil {
+ return nil, false, err
+ }
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+
+ updatedUser, err := s.GetProfile(ctx, userID)
+ if err != nil {
+ return nil, false, err
+ }
+ return updatedUser, true, nil
+}
+
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
+ if txRunner, ok := s.userRepo.(userProfileIdentityTxRunner); ok {
+ var (
+ updated *User
+ oldConcurrency int
+ )
+ if err := txRunner.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ var err error
+ updated, oldConcurrency, err = s.updateProfile(txCtx, userID, req)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+ if s.authCacheInvalidator != nil && updated != nil && updated.Concurrency != oldConcurrency {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ return updated, nil
+ }
+
+ updated, oldConcurrency, err := s.updateProfile(ctx, userID, req)
+ if err != nil {
+ return nil, err
+ }
+ if s.authCacheInvalidator != nil && updated.Concurrency != oldConcurrency {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ return updated, nil
+}
+
+func (s *UserService) updateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, int, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
- return nil, fmt.Errorf("get user: %w", err)
+ return nil, 0, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
@@ -131,10 +414,10 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
// 检查新邮箱是否已被使用
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
if err != nil {
- return nil, fmt.Errorf("check email exists: %w", err)
+ return nil, oldConcurrency, fmt.Errorf("check email exists: %w", err)
}
if exists && *req.Email != user.Email {
- return nil, ErrEmailExists
+ return nil, oldConcurrency, ErrEmailExists
}
user.Email = *req.Email
}
@@ -143,6 +426,14 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
+ if req.AvatarURL != nil {
+ avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL)
+ if err != nil {
+ return nil, oldConcurrency, err
+ }
+ applyUserAvatar(user, avatar)
+ }
+
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
@@ -159,13 +450,465 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
}
if err := s.userRepo.Update(ctx, user); err != nil {
- return nil, fmt.Errorf("update user: %w", err)
- }
- if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
- s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ return nil, oldConcurrency, fmt.Errorf("update user: %w", err)
}
- return user, nil
+ return user, oldConcurrency, nil
+}
+
+func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) {
+ avatarValue := strings.TrimSpace(raw)
+ if avatarValue == "" {
+ if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
+ return nil, fmt.Errorf("delete avatar: %w", err)
+ }
+ return nil, nil
+ }
+
+ avatarInput, err := normalizeUserAvatarInput(avatarValue)
+ if err != nil {
+ return nil, err
+ }
+
+ avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
+ if err != nil {
+ return nil, fmt.Errorf("upsert avatar: %w", err)
+ }
+ return avatar, nil
+}
+
+func applyUserAvatar(user *User, avatar *UserAvatar) {
+ if user == nil {
+ return
+ }
+ if avatar == nil {
+ user.AvatarURL = ""
+ user.AvatarSource = ""
+ user.AvatarMIME = ""
+ user.AvatarByteSize = 0
+ user.AvatarSHA256 = ""
+ return
+ }
+
+ user.AvatarURL = avatar.URL
+ user.AvatarSource = avatar.StorageProvider
+ user.AvatarMIME = avatar.ContentType
+ user.AvatarByteSize = avatar.ByteSize
+ user.AvatarSHA256 = avatar.SHA256
+}
+
+func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if strings.HasPrefix(raw, "data:") {
+ return normalizeInlineUserAvatarInput(raw)
+ }
+
+ parsed, err := url.Parse(raw)
+ if err != nil || parsed == nil {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if strings.TrimSpace(parsed.Host) == "" {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+
+ return UpsertUserAvatarInput{
+ StorageProvider: "remote_url",
+ URL: raw,
+ }, nil
+}
+
+func ValidateUserAvatar(raw string) error {
+ _, err := normalizeUserAvatarInput(raw)
+ return err
+}
+
+func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
+ body := strings.TrimPrefix(raw, "data:")
+ meta, encoded, ok := strings.Cut(body, ",")
+ if !ok {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ meta = strings.TrimSpace(meta)
+ encoded = strings.TrimSpace(encoded)
+ if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+
+ contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
+ if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
+ return UpsertUserAvatarInput{}, ErrAvatarNotImage
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ return UpsertUserAvatarInput{}, ErrAvatarInvalid
+ }
+ if len(decoded) > maxInlineAvatarBytes {
+ return UpsertUserAvatarInput{}, ErrAvatarTooLarge
+ }
+
+ if len(decoded) > targetAvatarBytes {
+ decoded, contentType, err = compressInlineAvatar(decoded)
+ if err != nil {
+ return UpsertUserAvatarInput{}, err
+ }
+ raw = "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(decoded)
+ }
+
+ sum := sha256.Sum256(decoded)
+ return UpsertUserAvatarInput{
+ StorageProvider: "inline",
+ URL: raw,
+ ContentType: contentType,
+ ByteSize: len(decoded),
+ SHA256: hex.EncodeToString(sum[:]),
+ }, nil
+}
+
+func compressInlineAvatar(decoded []byte) ([]byte, string, error) {
+ src, _, err := image.Decode(bytes.NewReader(decoded))
+ if err != nil {
+ return nil, "", ErrAvatarInvalid
+ }
+
+ srcBounds := src.Bounds()
+ if srcBounds.Empty() {
+ return nil, "", ErrAvatarInvalid
+ }
+
+ for _, scale := range avatarScaleSteps {
+ width := max(1, int(float64(srcBounds.Dx())*scale))
+ height := max(1, int(float64(srcBounds.Dy())*scale))
+ dst := image.NewRGBA(image.Rect(0, 0, width, height))
+ stddraw.Draw(dst, dst.Bounds(), &image.Uniform{C: color.White}, image.Point{}, stddraw.Src)
+ xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, srcBounds, stddraw.Over, nil)
+
+ for _, quality := range avatarQualitySteps {
+ var buf bytes.Buffer
+ if err := jpeg.Encode(&buf, dst, &jpeg.Options{Quality: quality}); err != nil {
+ return nil, "", ErrAvatarInvalid
+ }
+ if buf.Len() <= targetAvatarBytes {
+ return buf.Bytes(), "image/jpeg", nil
+ }
+ }
+ }
+
+ return nil, "", ErrAvatarTooLarge
+}
+
+func (s *UserService) buildEmailIdentitySummary(user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
+ summary := UserIdentitySummary{
+ Provider: "email",
+ CanBind: false,
+ CanUnbind: false,
+ NoteKey: userIdentityNoteEmailManagedFromProfile,
+ Note: "Primary account email is managed from the profile form.",
+ }
+ if user == nil {
+ return summary
+ }
+
+ filtered := filterUserAuthIdentities(records, "email")
+ if len(filtered) > 0 {
+ primary := selectPrimaryUserAuthIdentity(filtered)
+ email := strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "email"))
+ if email == "" {
+ email = strings.TrimSpace(primary.ProviderSubject)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(user.Email)
+ }
+ if email == "" || isReservedEmail(email) {
+ email = strings.TrimSpace(primary.ProviderKey)
+ }
+
+ summary.Bound = true
+ summary.BoundCount = len(filtered)
+ summary.DisplayName = email
+ summary.SubjectHint = maskEmailIdentity(email)
+ summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
+ summary.VerifiedAt = primary.VerifiedAt
+ return summary
+ }
+
+ // Compatibility fallback for legacy normal-email users that predate auth_identities backfill.
+ email := strings.TrimSpace(user.Email)
+ if email == "" || isReservedEmail(email) {
+ return summary
+ }
+ summary.Bound = true
+ summary.BoundCount = 1
+ summary.DisplayName = email
+ summary.SubjectHint = maskEmailIdentity(email)
+ summary.ProviderKey = "email"
+ return summary
+}
+
+func (s *UserService) buildProviderIdentitySummary(provider string, user *User, records []UserAuthIdentityRecord) UserIdentitySummary {
+ summary := UserIdentitySummary{
+ Provider: provider,
+ CanUnbind: false,
+ }
+ filtered := filterUserAuthIdentities(records, provider)
+ if len(filtered) == 0 {
+ summary.CanBind = true
+ bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "")
+ if err == nil {
+ summary.BindStartPath = bindStartPath
+ }
+ return summary
+ }
+
+ primary := selectPrimaryUserAuthIdentity(filtered)
+ summary.Bound = true
+ summary.BoundCount = len(filtered)
+ summary.DisplayName = userAuthIdentityDisplayName(primary)
+ summary.AvatarURL = strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "avatar_url", "suggested_avatar_url", "headimgurl"))
+ summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
+ summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
+ summary.VerifiedAt = primary.VerifiedAt
+ summary.CanUnbind = s.canUnbindProvider(provider, user, records)
+ if summary.CanUnbind {
+ summary.NoteKey = userIdentityNoteCanUnbind
+ summary.Note = "You can unbind this sign-in method."
+ } else {
+ summary.NoteKey = userIdentityNoteBindAnotherBeforeUnbind
+ summary.Note = "Bind another sign-in method before unbinding."
+ }
+ return summary
+}
+
+func (s *UserService) canUnbindProvider(provider string, user *User, records []UserAuthIdentityRecord) bool {
+ if provider == "" || provider == "email" || len(filterUserAuthIdentities(records, provider)) == 0 {
+ return false
+ }
+
+ if s.canUseEmailAsSignInMethod(user, records) {
+ return true
+ }
+
+ for _, candidate := range []string{"linuxdo", "oidc", "wechat"} {
+ if candidate == provider {
+ continue
+ }
+ if len(filterUserAuthIdentities(records, candidate)) > 0 {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (s *UserService) canUseEmailAsSignInMethod(user *User, records []UserAuthIdentityRecord) bool {
+ if user == nil {
+ return false
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return false
+ }
+
+ if emailSignupSourceAllowsLogin(user.SignupSource) {
+ return true
+ }
+
+ for _, record := range filterUserAuthIdentities(records, "email") {
+ if emailIdentitySupportsSignIn(record) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func emailSignupSourceAllowsLogin(signupSource string) bool {
+ signupSource = strings.ToLower(strings.TrimSpace(signupSource))
+ return signupSource == "" || signupSource == "email"
+}
+
+func emailIdentitySupportsSignIn(record UserAuthIdentityRecord) bool {
+ source := strings.TrimSpace(firstStringIdentityValue(record.Metadata, "source"))
+ switch source {
+ case "auth_service_email_bind", "auth_service_login_backfill", "auth_service_dual_write":
+ return true
+ default:
+ return false
+ }
+}
+
+func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ if userID <= 0 || s == nil || s.userRepo == nil {
+ return nil, nil
+ }
+ return s.userRepo.ListUserAuthIdentities(ctx, userID)
+}
+
+func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) {
+ provider = normalizeUserIdentityProvider(provider)
+ if provider == "" || provider == "email" {
+ return "", ErrIdentityProviderInvalid
+ }
+
+ redirectTo, err := normalizeUserIdentityRedirect(redirectTo)
+ if err != nil {
+ return "", err
+ }
+
+ path := ""
+ switch provider {
+ case "linuxdo":
+ path = "/api/v1/auth/oauth/linuxdo/bind/start"
+ case "oidc":
+ path = "/api/v1/auth/oauth/oidc/bind/start"
+ case "wechat":
+ path = "/api/v1/auth/oauth/wechat/bind/start"
+ default:
+ return "", ErrIdentityProviderInvalid
+ }
+
+ query := url.Values{}
+ query.Set("redirect", redirectTo)
+ query.Set("intent", "bind_current_user")
+ return path + "?" + query.Encode(), nil
+}
+
+func normalizeUserIdentityProvider(provider string) string {
+ switch strings.ToLower(strings.TrimSpace(provider)) {
+ case "linuxdo":
+ return "linuxdo"
+ case "oidc":
+ return "oidc"
+ case "wechat":
+ return "wechat"
+ case "email":
+ return "email"
+ default:
+ return ""
+ }
+}
+
+func normalizeUserIdentityRedirect(raw string) (string, error) {
+ redirect := strings.TrimSpace(raw)
+ if redirect == "" {
+ return defaultUserIdentityRedirect, nil
+ }
+ if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
+ return "", ErrIdentityRedirectInvalid
+ }
+ return redirect, nil
+}
+
+func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord {
+ if len(records) == 0 {
+ return nil
+ }
+ filtered := make([]UserAuthIdentityRecord, 0, len(records))
+ for _, record := range records {
+ if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) {
+ filtered = append(filtered, record)
+ }
+ }
+ return filtered
+}
+
+func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord {
+ if len(records) == 0 {
+ return UserAuthIdentityRecord{}
+ }
+ sort.SliceStable(records, func(i, j int) bool {
+ left := userAuthIdentitySortTime(records[i])
+ right := userAuthIdentitySortTime(records[j])
+ if !left.Equal(right) {
+ return left.After(right)
+ }
+ return records[i].ProviderKey < records[j].ProviderKey
+ })
+ return records[0]
+}
+
+func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time {
+ if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() {
+ return record.VerifiedAt.UTC()
+ }
+ if !record.UpdatedAt.IsZero() {
+ return record.UpdatedAt.UTC()
+ }
+ if !record.CreatedAt.IsZero() {
+ return record.CreatedAt.UTC()
+ }
+ return time.Time{}
+}
+
+func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string {
+ if displayName := firstStringIdentityValue(record.Metadata,
+ "display_name",
+ "suggested_display_name",
+ "username",
+ "name",
+ "nickname",
+ "email",
+ ); displayName != "" {
+ return displayName
+ }
+ if subject := strings.TrimSpace(record.ProviderSubject); subject != "" {
+ return subject
+ }
+ return strings.TrimSpace(record.ProviderType)
+}
+
+func firstStringIdentityValue(values map[string]any, keys ...string) string {
+ for _, key := range keys {
+ raw, ok := values[key]
+ if !ok {
+ continue
+ }
+ switch value := raw.(type) {
+ case string:
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ case fmt.Stringer:
+ if trimmed := strings.TrimSpace(value.String()); trimmed != "" {
+ return trimmed
+ }
+ }
+ }
+ return ""
+}
+
+func maskEmailIdentity(email string) string {
+ local, domain, ok := strings.Cut(strings.TrimSpace(email), "@")
+ if !ok || local == "" || domain == "" {
+ return maskOpaqueIdentity(email)
+ }
+ runes := []rune(local)
+ if len(runes) == 1 {
+ return string(runes[0]) + "***@" + domain
+ }
+ return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain
+}
+
+func maskOpaqueIdentity(value string) string {
+ value = strings.TrimSpace(value)
+ runes := []rune(value)
+ switch {
+ case len(runes) == 0:
+ return ""
+ case len(runes) <= 4:
+ return string(runes[0]) + "***"
+ case len(runes) <= 8:
+ return string(runes[:2]) + "***" + string(runes[len(runes)-1:])
+ default:
+ return string(runes[:3]) + "***" + string(runes[len(runes)-3:])
+ }
}
// ChangePassword 修改密码
@@ -202,9 +945,94 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
+ normalizeLoadedUserTokenVersion(user)
+ if err := s.hydrateUserAvatar(ctx, user); err != nil {
+ return nil, fmt.Errorf("get user avatar: %w", err)
+ }
return user, nil
}
+func normalizeLoadedUserTokenVersion(user *User) {
+ if user == nil || user.TokenVersionResolved {
+ return
+ }
+ user.TokenVersion = resolvedTokenVersion(user)
+ user.TokenVersionResolved = true
+}
+
+// TouchLastActive 通过防抖更新 users.last_active_at,减少鉴权热路径写放大。
+// 该操作为尽力而为,不应中断正常请求。
+func (s *UserService) TouchLastActive(ctx context.Context, userID int64) {
+ if s == nil || s.userRepo == nil || userID <= 0 {
+ return
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ slog.Debug("skip touch user last active after load failure", "user_id", userID, "error", err)
+ return
+ }
+ s.TouchLastActiveForUser(ctx, user)
+}
+
+// TouchLastActiveForUser 使用已加载的用户信息更新 last_active_at,避免重复读取数据库。
+func (s *UserService) TouchLastActiveForUser(ctx context.Context, user *User) {
+ if s == nil || s.userRepo == nil || user == nil || user.ID <= 0 {
+ return
+ }
+
+ now := time.Now()
+ if userLastActiveFresh(user.LastActiveAt, now) {
+ return
+ }
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
+ return
+ }
+ }
+
+ _, err, _ := s.lastActiveTouchSF.Do(strconv.FormatInt(user.ID, 10), func() (any, error) {
+ latest := time.Now()
+ if v, ok := s.lastActiveTouchL1.Load(user.ID); ok {
+ if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
+ return nil, nil
+ }
+ }
+ if userLastActiveFresh(user.LastActiveAt, latest) {
+ return nil, nil
+ }
+ if err := s.userRepo.UpdateUserLastActiveAt(ctx, user.ID, latest); err != nil {
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveFailBackoff))
+ return nil, fmt.Errorf("touch user last active: %w", err)
+ }
+ s.lastActiveTouchL1.Store(user.ID, latest.Add(userLastActiveMinTouch))
+ return nil, nil
+ })
+ if err != nil {
+ slog.Warn("touch user last active failed", "user_id", user.ID, "error", err)
+ }
+}
+
+func userLastActiveFresh(lastActiveAt *time.Time, now time.Time) bool {
+ if lastActiveAt == nil {
+ return false
+ }
+ return now.Before(lastActiveAt.Add(userLastActiveMinTouch))
+}
+
+func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
+ if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
+ return nil
+ }
+
+ avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
+ if err != nil {
+ return err
+ }
+ applyUserAvatar(user, avatar)
+ return nil
+}
+
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
diff --git a/backend/internal/service/user_service_email_identity_sync_test.go b/backend/internal/service/user_service_email_identity_sync_test.go
new file mode 100644
index 00000000..702b3b1a
--- /dev/null
+++ b/backend/internal/service/user_service_email_identity_sync_test.go
@@ -0,0 +1,34 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestUpdateProfile_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 19,
+ Email: "profile-before@example.com",
+ Username: "tester",
+ Concurrency: 2,
+ },
+ replaceErr: context.DeadlineExceeded,
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ newEmail := "profile-after@example.com"
+ updated, err := svc.UpdateProfile(context.Background(), 19, UpdateProfileRequest{
+ Email: &newEmail,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, newEmail, updated.Email)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Empty(t, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index a998d5f4..ff55c2a5 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -3,8 +3,14 @@
package service
import (
+ "bytes"
"context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
"errors"
+ "image"
+ "image/png"
"sync"
"sync/atomic"
"testing"
@@ -17,16 +23,159 @@ import (
// --- mock: UserRepository ---
type mockUserRepo struct {
- updateBalanceErr error
- updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ updateBalanceErr error
+ updateBalanceFn func(ctx context.Context, id int64, amount float64) error
+ getByIDUser *User
+ getByIDErr error
+ identities []UserAuthIdentityRecord
+ unbindIdentityErr error
+ unboundProviders []string
+ updateLastActiveErr error
+ updateLastActiveUserIDs []int64
+ updateLastActiveAt []time.Time
+ updateFn func(ctx context.Context, user *User) error
+ updateCalls int
+ upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarFn func(ctx context.Context, userID int64) error
+ deleteAvatarIDs []int64
+ getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
+ txCalls int
}
-func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
-func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
+type mockUserRepoTxKey struct{}
+
+type mockUserRepoTxState struct {
+ getByIDUser *User
+ upsertAvatarArgs []UpsertUserAvatarInput
+ deleteAvatarIDs []int64
+}
+
+type mockUserSettingRepo struct {
+ values map[string]string
+}
+
+func (m *mockUserSettingRepo) Get(context.Context, string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (m *mockUserSettingRepo) GetValue(context.Context, string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (m *mockUserSettingRepo) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (m *mockUserSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := m.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (m *mockUserSettingRepo) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (m *mockUserSettingRepo) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (m *mockUserSettingRepo) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
+func (m *mockUserRepo) GetByID(ctx context.Context, _ int64) (*User, error) {
+ if m.getByIDErr != nil {
+ return nil, m.getByIDErr
+ }
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil && txState.getByIDUser != nil {
+ cloned := *txState.getByIDUser
+ return &cloned, nil
+ }
+ if m.getByIDUser != nil {
+ cloned := *m.getByIDUser
+ return &cloned, nil
+ }
+ return &User{}, nil
+}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
-func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
+ m.updateCalls++
+ if m.updateFn != nil {
+ return m.updateFn(ctx, user)
+ }
+ return nil
+}
+func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ if m.getAvatarFn != nil {
+ return m.getAvatarFn(ctx, userID)
+ }
+ return nil, nil
+}
+func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil {
+ txState.upsertAvatarArgs = append(txState.upsertAvatarArgs, input)
+ if txState.getByIDUser != nil {
+ txState.getByIDUser.AvatarURL = input.URL
+ txState.getByIDUser.AvatarSource = input.StorageProvider
+ txState.getByIDUser.AvatarMIME = input.ContentType
+ txState.getByIDUser.AvatarByteSize = input.ByteSize
+ txState.getByIDUser.AvatarSHA256 = input.SHA256
+ }
+ if m.upsertAvatarFn != nil {
+ return m.upsertAvatarFn(ctx, userID, input)
+ }
+ return &UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+ }
+ m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
+ if m.upsertAvatarFn != nil {
+ return m.upsertAvatarFn(ctx, userID, input)
+ }
+ return &UserAvatar{
+ StorageProvider: input.StorageProvider,
+ StorageKey: input.StorageKey,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ if txState, _ := ctx.Value(mockUserRepoTxKey{}).(*mockUserRepoTxState); txState != nil {
+ txState.deleteAvatarIDs = append(txState.deleteAvatarIDs, userID)
+ if txState.getByIDUser != nil {
+ txState.getByIDUser.AvatarURL = ""
+ txState.getByIDUser.AvatarSource = ""
+ txState.getByIDUser.AvatarMIME = ""
+ txState.getByIDUser.AvatarByteSize = 0
+ txState.getByIDUser.AvatarSHA256 = ""
+ }
+ if m.deleteAvatarFn != nil {
+ return m.deleteAvatarFn(ctx, userID)
+ }
+ return nil
+ }
+ m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
+ if m.deleteAvatarFn != nil {
+ return m.deleteAvatarFn(ctx, userID)
+ }
+ return nil
+}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@@ -39,6 +188,11 @@ func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float
}
return m.updateBalanceErr
}
+func (m *mockUserRepo) UpdateUserLastActiveAt(_ context.Context, userID int64, activeAt time.Time) error {
+ m.updateLastActiveUserIDs = append(m.updateLastActiveUserIDs, userID)
+ m.updateLastActiveAt = append(m.updateLastActiveAt, activeAt)
+ return m.updateLastActiveErr
+}
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
@@ -46,12 +200,58 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
-func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ out := make([]UserAuthIdentityRecord, len(m.identities))
+ copy(out, m.identities)
+ return out, nil
+}
+func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
+func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
+func (m *mockUserRepo) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
+ if m.unbindIdentityErr != nil {
+ return m.unbindIdentityErr
+ }
+ m.unboundProviders = append(m.unboundProviders, provider)
+ filtered := m.identities[:0]
+ for _, identity := range m.identities {
+ if identity.ProviderType == provider {
+ continue
+ }
+ filtered = append(filtered, identity)
+ }
+ m.identities = append([]UserAuthIdentityRecord(nil), filtered...)
+ return nil
+}
+
+func (m *mockUserRepo) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ m.txCalls++
+ txState := &mockUserRepoTxState{
+ upsertAvatarArgs: append([]UpsertUserAvatarInput(nil), m.upsertAvatarArgs...),
+ deleteAvatarIDs: append([]int64(nil), m.deleteAvatarIDs...),
+ }
+ if m.getByIDUser != nil {
+ userCopy := *m.getByIDUser
+ txState.getByIDUser = &userCopy
+ }
+ err := fn(context.WithValue(ctx, mockUserRepoTxKey{}, txState))
+ if err != nil {
+ return err
+ }
+ m.getByIDUser = txState.getByIDUser
+ m.upsertAvatarArgs = txState.upsertAvatarArgs
+ m.deleteAvatarIDs = txState.deleteAvatarIDs
+ return nil
+}
// --- mock: APIKeyAuthCacheInvalidator ---
@@ -132,6 +332,225 @@ func TestUpdateBalance_Success(t *testing.T) {
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
}
+func TestGetProfileIdentitySummaries_AllowsUnbindWhenAnotherLoginMethodRemains(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 7,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-123456",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 7, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.True(t, summaries.LinuxDo.Bound)
+ require.True(t, summaries.LinuxDo.CanUnbind)
+ require.Equal(t, "linuxdo-handle", summaries.LinuxDo.DisplayName)
+ require.NotEmpty(t, summaries.LinuxDo.SubjectHint)
+}
+
+func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 9,
+ Email: "only-user@linuxdo-connect.invalid",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-only-subject",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ _, err := svc.UnbindUserAuthProvider(context.Background(), 9, "linuxdo")
+
+ require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
+ require.Empty(t, repo.unboundProviders)
+}
+
+func TestGetProfileIdentitySummaries_DoesNotTreatOAuthOnlyCompatEmailAsAlternativeLoginMethod(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 10,
+ Email: "oauth-only@example.com",
+ SignupSource: "oidc",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example.com",
+ ProviderSubject: "oidc-only-subject",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 10, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.False(t, summaries.OIDC.CanUnbind)
+
+ _, err = svc.UnbindUserAuthProvider(context.Background(), 10, "oidc")
+ require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
+ require.Empty(t, repo.unboundProviders)
+}
+
+func TestGetProfileIdentitySummaries_DoesNotTreatCompatBackfilledEmailIdentityAsAlternativeLoginMethod(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 11,
+ Email: "oauth-only@example.com",
+ SignupSource: "wechat",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "oauth-only@example.com",
+ Metadata: map[string]any{
+ "backfill_source": "users.email",
+ "migration": "109_auth_identity_compat_backfill",
+ },
+ },
+ {
+ ProviderType: "wechat",
+ ProviderKey: "wechat",
+ ProviderSubject: "wechat-only-subject",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 11, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.True(t, summaries.Email.Bound)
+ require.False(t, summaries.WeChat.CanUnbind)
+
+ _, err = svc.UnbindUserAuthProvider(context.Background(), 11, "wechat")
+ require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
+ require.Empty(t, repo.unboundProviders)
+}
+
+func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 12,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-12",
+ },
+ },
+ }
+ invalidator := &mockAuthCacheInvalidator{}
+ svc := NewUserService(repo, nil, invalidator, nil)
+
+ user, err := svc.UnbindUserAuthProvider(context.Background(), 12, "linuxdo")
+
+ require.NoError(t, err)
+ require.Equal(t, []string{"linuxdo"}, repo.unboundProviders)
+ require.Equal(t, int64(12), user.ID)
+ require.Equal(t, []int64{12}, invalidator.invalidatedUserIDs)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 12, user)
+ require.NoError(t, err)
+ require.False(t, summaries.LinuxDo.Bound)
+ require.True(t, summaries.LinuxDo.CanBind)
+}
+
+func TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabled(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 15,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ },
+ }
+ settingRepo := &mockUserSettingRepo{
+ values: map[string]string{
+ SettingKeyLinuxDoConnectEnabled: "false",
+ },
+ }
+ svc := NewUserService(repo, settingRepo, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 15, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.False(t, summaries.LinuxDo.Bound)
+ require.False(t, summaries.LinuxDo.CanBind)
+ require.Empty(t, summaries.LinuxDo.BindStartPath)
+}
+
+func TestGetProfileIdentitySummaries_UsesBindStartRoute(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 16,
+ Email: "alice@example.com",
+ },
+ identities: []UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "alice@example.com",
+ },
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 16, repo.getByIDUser)
+
+ require.NoError(t, err)
+ require.Equal(
+ t,
+ "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
+ summaries.LinuxDo.BindStartPath,
+ )
+ require.Equal(
+ t,
+ "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
+ summaries.OIDC.BindStartPath,
+ )
+ require.Equal(
+ t,
+ "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
+ summaries.WeChat.BindStartPath,
+ )
+}
+
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
@@ -154,6 +573,39 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
+func TestTouchLastActive_UpdatesWhenStale(t *testing.T) {
+ stale := time.Now().Add(-11 * time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &stale,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Equal(t, []int64{42}, repo.updateLastActiveUserIDs)
+ require.Len(t, repo.updateLastActiveAt, 1)
+ require.WithinDuration(t, time.Now(), repo.updateLastActiveAt[0], 2*time.Second)
+}
+
+func TestTouchLastActive_SkipsWhenRecent(t *testing.T) {
+ recent := time.Now().Add(-time.Minute)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 42,
+ LastActiveAt: &recent,
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ svc.TouchLastActive(context.Background(), 42)
+
+ require.Empty(t, repo.updateLastActiveUserIDs)
+ require.Empty(t, repo.updateLastActiveAt)
+}
+
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
@@ -200,3 +652,199 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
+
+func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
+ raw := []byte("small-avatar")
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+ expectedSum := sha256.Sum256(raw)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 7,
+ Email: "avatar@example.com",
+ Username: "avatar-user",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+ require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
+ require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
+ require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
+ require.Equal(t, dataURL, updated.AvatarURL)
+ require.Equal(t, "inline", updated.AvatarSource)
+ require.Equal(t, "image/png", updated.AvatarMIME)
+ require.Equal(t, len(raw), updated.AvatarByteSize)
+ require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
+}
+
+func TestUpdateProfile_CompressesInlineAvatarToTwentyKilobytes(t *testing.T) {
+ var encoded bytes.Buffer
+ for _, size := range []int{192, 224, 256, 288} {
+ encoded.Reset()
+ var img image.RGBA
+ img.Rect = image.Rect(0, 0, size, size)
+ img.Stride = size * 4
+ img.Pix = make([]byte, size*size*4)
+ for y := 0; y < size; y++ {
+ for x := 0; x < size; x++ {
+ offset := y*img.Stride + x*4
+ img.Pix[offset] = uint8((x*x + y*17) % 255)
+ img.Pix[offset+1] = uint8((y*y + x*29) % 255)
+ img.Pix[offset+2] = uint8(((x * y) + x*13 + y*7) % 255)
+ img.Pix[offset+3] = 0xff
+ }
+ }
+ require.NoError(t, png.Encode(&encoded, &img))
+ if encoded.Len() > 20*1024 && encoded.Len() <= maxInlineAvatarBytes {
+ break
+ }
+ }
+ require.Greater(t, encoded.Len(), 20*1024)
+ require.LessOrEqual(t, encoded.Len(), maxInlineAvatarBytes)
+
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(encoded.Bytes())
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 17,
+ Email: "avatar-compress@example.com",
+ Username: "avatar-compress",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 17, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
+ require.LessOrEqual(t, repo.upsertAvatarArgs[0].ByteSize, 20*1024)
+ require.Equal(t, "image/jpeg", repo.upsertAvatarArgs[0].ContentType)
+ require.Contains(t, repo.upsertAvatarArgs[0].URL, "data:image/jpeg;base64,")
+ require.Equal(t, "inline", updated.AvatarSource)
+ require.Equal(t, "image/jpeg", updated.AvatarMIME)
+ require.LessOrEqual(t, updated.AvatarByteSize, 20*1024)
+ require.Contains(t, updated.AvatarURL, "data:image/jpeg;base64,")
+ require.NotEmpty(t, updated.AvatarSHA256)
+}
+
+func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
+ raw := make([]byte, maxInlineAvatarBytes+1)
+ dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 8,
+ Email: "large-avatar@example.com",
+ Username: "too-large",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
+ AvatarURL: &dataURL,
+ })
+ require.ErrorIs(t, err, ErrAvatarTooLarge)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, repo.deleteAvatarIDs)
+ require.Zero(t, repo.updateCalls)
+}
+
+func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
+ remoteURL := "https://cdn.example.com/avatar.png"
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 9,
+ Email: "remote-avatar@example.com",
+ Username: "remote-avatar",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
+ AvatarURL: &remoteURL,
+ })
+ require.NoError(t, err)
+ require.Len(t, repo.upsertAvatarArgs, 1)
+ require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
+ require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
+ require.Equal(t, remoteURL, updated.AvatarURL)
+ require.Equal(t, "remote_url", updated.AvatarSource)
+ require.Zero(t, updated.AvatarByteSize)
+}
+
+func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
+ empty := ""
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 10,
+ Email: "delete-avatar@example.com",
+ Username: "delete-avatar",
+ AvatarURL: "https://cdn.example.com/old.png",
+ AvatarSource: "remote_url",
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
+ AvatarURL: &empty,
+ })
+ require.NoError(t, err)
+ require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, updated.AvatarURL)
+ require.Empty(t, updated.AvatarSource)
+}
+
+func TestUpdateProfile_RollsBackAvatarMutationWhenUserUpdateFails(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 11,
+ Email: "rollback@example.com",
+ AvatarURL: "https://cdn.example.com/original.png",
+ AvatarSource: "remote_url",
+ },
+ updateFn: func(context.Context, *User) error {
+ return errors.New("write user failed")
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ remoteURL := "https://cdn.example.com/new.png"
+ _, err := svc.UpdateProfile(context.Background(), 11, UpdateProfileRequest{
+ AvatarURL: &remoteURL,
+ })
+
+ require.EqualError(t, err, "update user: write user failed")
+ require.Equal(t, 1, repo.txCalls)
+ require.Empty(t, repo.upsertAvatarArgs)
+ require.Empty(t, repo.deleteAvatarIDs)
+ require.Equal(t, "https://cdn.example.com/original.png", repo.getByIDUser.AvatarURL)
+ require.Equal(t, "remote_url", repo.getByIDUser.AvatarSource)
+}
+
+func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
+ repo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: 12,
+ Email: "profile-avatar@example.com",
+ Username: "profile-avatar",
+ },
+ getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
+ return &UserAvatar{
+ StorageProvider: "remote_url",
+ URL: "https://cdn.example.com/profile.png",
+ }, nil
+ },
+ }
+ svc := NewUserService(repo, nil, nil, nil)
+
+ user, err := svc.GetProfile(context.Background(), 12)
+ require.NoError(t, err)
+ require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
+ require.Equal(t, "remote_url", user.AvatarSource)
+}
diff --git a/backend/internal/service/vertex_service_account.go b/backend/internal/service/vertex_service_account.go
new file mode 100644
index 00000000..4430cf81
--- /dev/null
+++ b/backend/internal/service/vertex_service_account.go
@@ -0,0 +1,345 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "regexp"
+ "strings"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+const (
+ vertexDefaultLocation = "us-central1"
+ vertexDefaultTokenURL = "https://oauth2.googleapis.com/token"
+ vertexCloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
+ vertexServiceAccountCacheSkew = 5 * time.Minute
+ vertexLockWaitTime = 200 * time.Millisecond
+ vertexAnthropicVersion = "vertex-2023-10-16"
+)
+
+var (
+ vertexLocationPattern = regexp.MustCompile(`^[a-z0-9-]+$`)
+ vertexAnthropicDatedModelIDPattern = regexp.MustCompile(`^(.+)-([0-9]{8})$`)
+ vertexAnthropicAlreadyDatedIDPattern = regexp.MustCompile(`^.+@[0-9]{8}$`)
+)
+
+type vertexServiceAccountKey struct {
+ Type string `json:"type"`
+ ProjectID string `json:"project_id"`
+ PrivateKeyID string `json:"private_key_id"`
+ PrivateKey string `json:"private_key"`
+ ClientEmail string `json:"client_email"`
+ TokenURI string `json:"token_uri"`
+}
+
+type vertexTokenResponse struct {
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int64 `json:"expires_in"`
+ Error string `json:"error"`
+ ErrorDesc string `json:"error_description"`
+}
+
+func (a *Account) IsVertexServiceAccount() bool {
+ return a != nil && a.Type == AccountTypeServiceAccount
+}
+
+func (a *Account) VertexProjectID() string {
+ if a == nil {
+ return ""
+ }
+ if v := strings.TrimSpace(a.GetCredential("project_id")); v != "" {
+ return v
+ }
+ key, err := parseVertexServiceAccountKey(a)
+ if err == nil {
+ return strings.TrimSpace(key.ProjectID)
+ }
+ return ""
+}
+
+func (a *Account) VertexLocation(model string) string {
+ if a == nil {
+ return vertexDefaultLocation
+ }
+ if model != "" && a.Credentials != nil {
+ if raw, ok := a.Credentials["vertex_model_locations"].(map[string]any); ok {
+ if loc, ok := raw[model].(string); ok && strings.TrimSpace(loc) != "" {
+ return strings.TrimSpace(loc)
+ }
+ }
+ }
+ if v := strings.TrimSpace(a.GetCredential("location")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("vertex_location")); v != "" {
+ return v
+ }
+ return vertexDefaultLocation
+}
+
+func parseVertexServiceAccountKey(account *Account) (*vertexServiceAccountKey, error) {
+ if account == nil || account.Credentials == nil {
+ return nil, errors.New("service account credentials not configured")
+ }
+
+ if raw := strings.TrimSpace(account.GetCredential("service_account_json")); raw != "" {
+ return parseVertexServiceAccountJSON([]byte(raw))
+ }
+ if raw := strings.TrimSpace(account.GetCredential("service_account")); raw != "" {
+ return parseVertexServiceAccountJSON([]byte(raw))
+ }
+ if nested, ok := account.Credentials["service_account_json"].(map[string]any); ok {
+ b, _ := json.Marshal(nested)
+ return parseVertexServiceAccountJSON(b)
+ }
+ if nested, ok := account.Credentials["service_account"].(map[string]any); ok {
+ b, _ := json.Marshal(nested)
+ return parseVertexServiceAccountJSON(b)
+ }
+ return nil, errors.New("service_account_json not found in credentials")
+}
+
+func parseVertexServiceAccountJSON(raw []byte) (*vertexServiceAccountKey, error) {
+ var key vertexServiceAccountKey
+ if err := json.Unmarshal(raw, &key); err != nil {
+ return nil, fmt.Errorf("invalid service account json: %w", err)
+ }
+ if strings.TrimSpace(key.ClientEmail) == "" {
+ return nil, errors.New("service account json missing client_email")
+ }
+ if strings.TrimSpace(key.PrivateKey) == "" {
+ return nil, errors.New("service account json missing private_key")
+ }
+ if strings.TrimSpace(key.ProjectID) == "" {
+ return nil, errors.New("service account json missing project_id")
+ }
+ // Always use the well-known Google token endpoint to prevent SSRF via crafted token_uri.
+ key.TokenURI = vertexDefaultTokenURL
+ return &key, nil
+}
+
+func vertexServiceAccountCacheKey(account *Account, key *vertexServiceAccountKey) string {
+ fingerprint := ""
+ if key != nil {
+ sum := sha256.Sum256([]byte(key.ClientEmail + "\x00" + key.PrivateKeyID))
+ fingerprint = hex.EncodeToString(sum[:8])
+ }
+ if fingerprint == "" && account != nil {
+ fingerprint = fmt.Sprintf("account:%d", account.ID)
+ }
+ return "vertex:service_account:" + fingerprint
+}
+
+// getVertexServiceAccountAccessToken obtains an access token for a Vertex service account,
+// using the shared cache and distributed lock to avoid redundant exchanges.
+func getVertexServiceAccountAccessToken(ctx context.Context, cache GeminiTokenCache, account *Account) (string, error) {
+ key, err := parseVertexServiceAccountKey(account)
+ if err != nil {
+ return "", err
+ }
+ cacheKey := vertexServiceAccountCacheKey(account, key)
+
+ if cache != nil {
+ if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ return token, nil
+ }
+ }
+
+ locked := false
+ if cache != nil {
+ var lockErr error
+ locked, lockErr = cache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+ if lockErr == nil && locked {
+ defer func() { _ = cache.ReleaseRefreshLock(ctx, cacheKey) }()
+ } else if lockErr != nil {
+ slog.Warn("vertex_service_account_token_lock_failed", "account_id", account.ID, "error", lockErr)
+ } else {
+ time.Sleep(vertexLockWaitTime)
+ if token, err := cache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ return token, nil
+ }
+ }
+ }
+
+ accessToken, ttl, err := exchangeVertexServiceAccountToken(ctx, key)
+ if err != nil {
+ return "", err
+ }
+ if cache != nil {
+ _ = cache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
+ }
+ return accessToken, nil
+}
+
+func exchangeVertexServiceAccountToken(ctx context.Context, key *vertexServiceAccountKey) (string, time.Duration, error) {
+ now := time.Now()
+ claims := jwt.MapClaims{
+ "iss": key.ClientEmail,
+ "scope": vertexCloudPlatformScope,
+ "aud": key.TokenURI,
+ "iat": now.Unix(),
+ "exp": now.Add(time.Hour).Unix(),
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ if strings.TrimSpace(key.PrivateKeyID) != "" {
+ token.Header["kid"] = key.PrivateKeyID
+ }
+ privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(key.PrivateKey))
+ if err != nil {
+ return "", 0, fmt.Errorf("parse service account private key: %w", err)
+ }
+ assertion, err := token.SignedString(privateKey)
+ if err != nil {
+ return "", 0, fmt.Errorf("sign service account assertion: %w", err)
+ }
+
+ values := url.Values{}
+ values.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
+ values.Set("assertion", assertion)
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, key.TokenURI, strings.NewReader(values.Encode()))
+ if err != nil {
+ return "", 0, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ client := &http.Client{Timeout: 15 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", 0, fmt.Errorf("service account token request failed: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ var parsed vertexTokenResponse
+ _ = json.Unmarshal(body, &parsed)
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ msg := strings.TrimSpace(parsed.ErrorDesc)
+ if msg == "" {
+ msg = strings.TrimSpace(parsed.Error)
+ }
+ if msg == "" {
+ msg = string(bytes.TrimSpace(body))
+ }
+ return "", 0, fmt.Errorf("service account token request returned %d: %s", resp.StatusCode, msg)
+ }
+ if strings.TrimSpace(parsed.AccessToken) == "" {
+ return "", 0, errors.New("service account token response missing access_token")
+ }
+ ttl := time.Duration(parsed.ExpiresIn) * time.Second
+ if ttl <= 0 {
+ ttl = time.Hour
+ }
+ if ttl > vertexServiceAccountCacheSkew {
+ ttl -= vertexServiceAccountCacheSkew
+ }
+ return parsed.AccessToken, ttl, nil
+}
+
+func buildVertexGeminiURL(projectID, location, model, action string, stream bool) (string, error) {
+ projectID = strings.TrimSpace(projectID)
+ location = strings.TrimSpace(location)
+ model = strings.TrimSpace(model)
+ action = strings.TrimSpace(action)
+ if projectID == "" {
+ return "", errors.New("vertex project_id is required")
+ }
+ if location == "" {
+ location = vertexDefaultLocation
+ }
+ if !vertexLocationPattern.MatchString(location) {
+ return "", fmt.Errorf("invalid vertex location: %s", location)
+ }
+ if model == "" {
+ return "", errors.New("vertex model is required")
+ }
+ switch action {
+ case "generateContent", "streamGenerateContent", "countTokens":
+ default:
+ return "", fmt.Errorf("unsupported vertex gemini action: %s", action)
+ }
+ host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
+ if location == "global" {
+ host = "aiplatform.googleapis.com"
+ }
+ u := fmt.Sprintf(
+ "https://%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+ host,
+ url.PathEscape(projectID),
+ url.PathEscape(location),
+ url.PathEscape(model),
+ action,
+ )
+ if stream {
+ u += "?alt=sse"
+ }
+ return u, nil
+}
+
+func buildVertexAnthropicURL(projectID, location, model string, stream bool) (string, error) {
+ projectID = strings.TrimSpace(projectID)
+ location = strings.TrimSpace(location)
+ model = strings.TrimSpace(model)
+ if projectID == "" {
+ return "", errors.New("vertex project_id is required")
+ }
+ if location == "" {
+ location = vertexDefaultLocation
+ }
+ if !vertexLocationPattern.MatchString(location) {
+ return "", fmt.Errorf("invalid vertex location: %s", location)
+ }
+ if model == "" {
+ return "", errors.New("vertex model is required")
+ }
+ action := "rawPredict"
+ if stream {
+ action = "streamRawPredict"
+ }
+ host := fmt.Sprintf("%s-aiplatform.googleapis.com", location)
+ if location == "global" {
+ host = "aiplatform.googleapis.com"
+ }
+ escapedModel := strings.ReplaceAll(url.PathEscape(model), "%40", "@")
+ return fmt.Sprintf(
+ "https://%s/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+ host,
+ url.PathEscape(projectID),
+ url.PathEscape(location),
+ escapedModel,
+ action,
+ ), nil
+}
+
+func normalizeVertexAnthropicModelID(model string) string {
+ model = strings.TrimSpace(model)
+ if model == "" || vertexAnthropicAlreadyDatedIDPattern.MatchString(model) {
+ return model
+ }
+ if m := vertexAnthropicDatedModelIDPattern.FindStringSubmatch(model); len(m) == 3 {
+ return m[1] + "@" + m[2]
+ }
+ return model
+}
+
+func buildVertexAnthropicRequestBody(body []byte) ([]byte, error) {
+ var payload map[string]any
+ if err := json.Unmarshal(body, &payload); err != nil {
+ return nil, fmt.Errorf("parse anthropic vertex request body: %w", err)
+ }
+ delete(payload, "model")
+ payload["anthropic_version"] = vertexAnthropicVersion
+ return json.Marshal(payload)
+}
diff --git a/backend/internal/service/vertex_service_account_test.go b/backend/internal/service/vertex_service_account_test.go
new file mode 100644
index 00000000..519f5b2f
--- /dev/null
+++ b/backend/internal/service/vertex_service_account_test.go
@@ -0,0 +1,77 @@
+package service
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestBuildVertexGeminiURL(t *testing.T) {
+ got, err := buildVertexGeminiURL("my-project", "us-central1", "gemini-3-pro", "streamGenerateContent", true)
+ require.NoError(t, err)
+ require.Equal(t, "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-3-pro:streamGenerateContent?alt=sse", got)
+}
+
+func TestBuildVertexGeminiURLUsesGlobalEndpointHost(t *testing.T) {
+ got, err := buildVertexGeminiURL("my-project", "global", "gemini-3-flash-preview", "streamGenerateContent", true)
+ require.NoError(t, err)
+ require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/google/models/gemini-3-flash-preview:streamGenerateContent?alt=sse", got)
+}
+
+func TestBuildVertexAnthropicURL(t *testing.T) {
+ got, err := buildVertexAnthropicURL("my-project", "us-east5", "claude-sonnet-4-5@20250929", false)
+ require.NoError(t, err)
+ require.Equal(t, "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict", got)
+}
+
+func TestBuildVertexAnthropicURLUsesGlobalEndpointHost(t *testing.T) {
+ got, err := buildVertexAnthropicURL("my-project", "global", "claude-haiku-4-5@20251001", true)
+ require.NoError(t, err)
+ require.Equal(t, "https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/publishers/anthropic/models/claude-haiku-4-5@20251001:streamRawPredict", got)
+}
+
+func TestNormalizeVertexAnthropicModelID(t *testing.T) {
+ require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5-20250929"))
+ require.Equal(t, "claude-sonnet-4-5@20250929", normalizeVertexAnthropicModelID("claude-sonnet-4-5@20250929"))
+ require.Equal(t, "claude-sonnet-4-6", normalizeVertexAnthropicModelID("claude-sonnet-4-6"))
+}
+
+func TestBuildVertexAnthropicRequestBody(t *testing.T) {
+ got, err := buildVertexAnthropicRequestBody([]byte(`{"model":"claude-sonnet-4-5","anthropic_version":"2023-06-01","max_tokens":64,"messages":[{"role":"user","content":"hi"}]}`))
+ require.NoError(t, err)
+ require.Equal(t, "", gjson.GetBytes(got, "model").String())
+ require.Equal(t, vertexAnthropicVersion, gjson.GetBytes(got, "anthropic_version").String())
+ require.Equal(t, int64(64), gjson.GetBytes(got, "max_tokens").Int())
+ require.Equal(t, "hi", gjson.GetBytes(got, "messages.0.content").String())
+}
+
+func TestBuildVertexGeminiURLRejectsInvalidLocation(t *testing.T) {
+ _, err := buildVertexGeminiURL("my-project", "us-central1/path", "gemini-3-pro", "generateContent", false)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid vertex location")
+}
+
+func TestParseVertexServiceAccountKey(t *testing.T) {
+ raw := `{
+ "type": "service_account",
+ "project_id": "vertex-proj",
+ "private_key_id": "kid",
+ "private_key": "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
+ "client_email": "svc@vertex-proj.iam.gserviceaccount.com"
+ }`
+ account := &Account{
+ Type: AccountTypeServiceAccount,
+ Platform: PlatformGemini,
+ Credentials: map[string]any{
+ "service_account_json": raw,
+ },
+ }
+ key, err := parseVertexServiceAccountKey(account)
+ require.NoError(t, err)
+ require.Equal(t, "vertex-proj", key.ProjectID)
+ require.Equal(t, "svc@vertex-proj.iam.gserviceaccount.com", key.ClientEmail)
+ require.Equal(t, vertexDefaultTokenURL, key.TokenURI)
+ require.True(t, strings.Contains(key.PrivateKey, "BEGIN PRIVATE KEY"))
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 9f33c46a..8b50e478 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3)
}
+// ProvideOAuthRefreshAPI creates OAuthRefreshAPI with the default lock TTL.
+func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI {
+ return NewOAuthRefreshAPI(accountRepo, tokenCache)
+}
+
// ProvideTokenRefreshService creates and starts TokenRefreshService
func ProvideTokenRefreshService(
accountRepo AccountRepository,
@@ -210,11 +215,13 @@ func ProvideRateLimitService(
geminiQuotaService *GeminiQuotaService,
tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache,
+ openAI403CounterCache OpenAI403CounterCache,
settingService *SettingService,
tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache)
+ svc.SetOpenAI403CounterCache(openAI403CounterCache)
svc.SetSettingService(settingService)
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc
@@ -262,13 +269,16 @@ func ProvideOpsAlertEvaluatorService(
}
// ProvideOpsCleanupService creates and starts OpsCleanupService (cron scheduled).
+// channelMonitorSvc 让维护任务(聚合 + 历史/聚合软删)跟随 ops 清理 cron 一起跑,
+// 共享 leader lock + heartbeat。
func ProvideOpsCleanupService(
opsRepo OpsRepository,
db *sql.DB,
redisClient *redis.Client,
cfg *config.Config,
+ channelMonitorSvc *ChannelMonitorService,
) *OpsCleanupService {
- svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg)
+ svc := NewOpsCleanupService(opsRepo, db, redisClient, cfg, channelMonitorSvc)
svc.Start()
return svc
}
@@ -381,12 +391,41 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
return svc
}
+// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
+func ProvideBillingCacheService(
+ cache BillingCache,
+ userRepo UserRepository,
+ subRepo UserSubscriptionRepository,
+ apiKeyRepo APIKeyRepository,
+ rpmCache UserRPMCache,
+ rateRepo UserGroupRateRepository,
+ cfg *config.Config,
+) *BillingCacheService {
+ return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg)
+}
+
+// ProvideAPIKeyService wires APIKeyService and connects rate-limit cache invalidation.
+func ProvideAPIKeyService(
+ apiKeyRepo APIKeyRepository,
+ userRepo UserRepository,
+ groupRepo GroupRepository,
+ userSubRepo UserSubscriptionRepository,
+ userGroupRateRepo UserGroupRateRepository,
+ cache APIKeyCache,
+ cfg *config.Config,
+ billingCacheService *BillingCacheService,
+) *APIKeyService {
+ svc := NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, userGroupRateRepo, cache, cfg)
+ svc.SetRateLimitCacheInvalidator(billingCacheService)
+ return svc
+}
+
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
- NewAPIKeyService,
+ ProvideAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,
@@ -397,7 +436,7 @@ var ProviderSet = wire.NewSet(
NewDashboardService,
ProvidePricingService,
NewBillingService,
- NewBillingCacheService,
+ ProvideBillingCacheService,
NewAnnouncementService,
NewAdminService,
NewGatewayService,
@@ -409,7 +448,7 @@ var ProviderSet = wire.NewSet(
NewCompositeTokenCacheInvalidator,
wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService,
- NewOAuthRefreshAPI,
+ ProvideOAuthRefreshAPI,
ProvideGeminiTokenProvider,
NewGeminiMessagesCompatService,
ProvideAntigravityTokenProvider,
@@ -463,10 +502,14 @@ var ProviderSet = wire.NewSet(
NewGroupCapacityService,
NewChannelService,
NewModelPricingResolver,
+ NewAffiliateService,
ProvidePaymentConfigService,
NewPaymentService,
ProvidePaymentOrderExpiryService,
ProvideBalanceNotifyService,
+ ProvideChannelMonitorService,
+ ProvideChannelMonitorRunner,
+ NewChannelMonitorRequestTemplateService,
)
// ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named
@@ -486,3 +529,23 @@ func ProvidePaymentOrderExpiryService(paymentSvc *PaymentService) *PaymentOrderE
svc.Start()
return svc
}
+
+// ProvideChannelMonitorService 创建渠道监控服务(CRUD + RunCheck + 用户视图聚合)。
+// 加密器复用 wire 中已注入的 SecretEncryptor(AES-256-GCM)。
+func ProvideChannelMonitorService(
+ repo ChannelMonitorRepository,
+ encryptor SecretEncryptor,
+) *ChannelMonitorService {
+ return NewChannelMonitorService(repo, encryptor)
+}
+
+// ProvideChannelMonitorRunner 创建并启动渠道监控调度器。
+// 通过 SetScheduler 注入回 service 后再 Start,确保启动时加载所有 enabled monitor,
+// 后续 CRUD 也能即时同步任务表。Runner.Stop 由 cleanup function 调用。
+// settingService 用于 runner 每次 fire 读取功能开关。
+func ProvideChannelMonitorRunner(svc *ChannelMonitorService, settingService *SettingService) *ChannelMonitorRunner {
+ r := NewChannelMonitorRunner(svc, settingService)
+ svc.SetScheduler(r)
+ r.Start()
+ return r
+}
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index 89d09eef..2279d913 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -301,11 +301,13 @@ func shouldBypassEmbeddedFrontend(path string) bool {
return strings.HasPrefix(trimmed, "/api/") ||
strings.HasPrefix(trimmed, "/v1/") ||
strings.HasPrefix(trimmed, "/v1beta/") ||
+ strings.HasPrefix(trimmed, "/backend-api/") ||
strings.HasPrefix(trimmed, "/antigravity/") ||
strings.HasPrefix(trimmed, "/setup/") ||
trimmed == "/health" ||
trimmed == "/responses" ||
- strings.HasPrefix(trimmed, "/responses/")
+ strings.HasPrefix(trimmed, "/responses/") ||
+ strings.HasPrefix(trimmed, "/images/")
}
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go
index 4127a7a6..583d98a0 100644
--- a/backend/internal/web/embed_test.go
+++ b/backend/internal/web/embed_test.go
@@ -434,6 +434,8 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users",
"/v1/models",
"/v1beta/chat",
+ "/backend-api/codex/responses",
+ "/backend-api/codex/responses/compact",
"/antigravity/test",
"/setup/init",
"/health",
@@ -636,6 +638,8 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users",
"/v1/models",
"/v1beta/chat",
+ "/backend-api/codex/responses",
+ "/backend-api/codex/responses/compact",
"/antigravity/test",
"/setup/init",
"/health",
diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql
new file mode 100644
index 00000000..117e3ca3
--- /dev/null
+++ b/backend/migrations/108_auth_identity_foundation_core.sql
@@ -0,0 +1,141 @@
+ALTER TABLE users
+ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
+ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
+ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
+
+UPDATE users
+SET signup_source = 'email'
+WHERE signup_source IS NULL OR signup_source = '';
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'users_signup_source_check'
+ ) THEN
+ ALTER TABLE users
+ ADD CONSTRAINT users_signup_source_check
+ CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
+ END IF;
+END $$;
+
+CREATE TABLE IF NOT EXISTS auth_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ verified_at TIMESTAMPTZ NULL,
+ issuer TEXT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identities_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
+ ON auth_identities (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
+ ON auth_identities (user_id);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
+ ON auth_identities (user_id, provider_type);
+
+CREATE TABLE IF NOT EXISTS auth_identity_channels (
+ id BIGSERIAL PRIMARY KEY,
+ identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ channel VARCHAR(20) NOT NULL,
+ channel_app_id TEXT NOT NULL,
+ channel_subject TEXT NOT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identity_channels_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
+ ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
+ ON auth_identity_channels (identity_id);
+
+CREATE TABLE IF NOT EXISTS pending_auth_sessions (
+ id BIGSERIAL PRIMARY KEY,
+ session_token VARCHAR(255) NOT NULL,
+ intent VARCHAR(40) NOT NULL,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ redirect_to TEXT NOT NULL DEFAULT '',
+ resolved_email TEXT NOT NULL DEFAULT '',
+ registration_password_hash TEXT NOT NULL DEFAULT '',
+ upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
+ local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
+ browser_session_key TEXT NOT NULL DEFAULT '',
+ completion_code_hash TEXT NOT NULL DEFAULT '',
+ completion_code_expires_at TIMESTAMPTZ NULL,
+ email_verified_at TIMESTAMPTZ NULL,
+ password_verified_at TIMESTAMPTZ NULL,
+ totp_verified_at TIMESTAMPTZ NULL,
+ expires_at TIMESTAMPTZ NOT NULL,
+ consumed_at TIMESTAMPTZ NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT pending_auth_sessions_intent_check
+ CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
+ CONSTRAINT pending_auth_sessions_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
+ ON pending_auth_sessions (session_token);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
+ ON pending_auth_sessions (target_user_id);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
+ ON pending_auth_sessions (expires_at);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
+ ON pending_auth_sessions (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
+ ON pending_auth_sessions (completion_code_hash);
+
+CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
+ id BIGSERIAL PRIMARY KEY,
+ pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
+ identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
+ adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
+ adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
+ decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
+ ON identity_adoption_decisions (pending_auth_session_id);
+
+CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
+ ON identity_adoption_decisions (identity_id);
+
+CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
+ id BIGSERIAL PRIMARY KEY,
+ report_type VARCHAR(40) NOT NULL,
+ report_key TEXT NOT NULL,
+ details JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
+ ON auth_identity_migration_reports (report_type);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
+ ON auth_identity_migration_reports (report_type, report_key);
diff --git a/backend/migrations/108a_widen_auth_identity_migration_report_type.sql b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql
new file mode 100644
index 00000000..bc170fb8
--- /dev/null
+++ b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql
@@ -0,0 +1,14 @@
+DO $$
+BEGIN
+ IF EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+ AND COALESCE(character_maximum_length, 0) < 80
+ ) THEN
+ ALTER TABLE auth_identity_migration_reports
+ ALTER COLUMN report_type TYPE VARCHAR(80);
+ END IF;
+END $$;
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
new file mode 100644
index 00000000..ddbbedbc
--- /dev/null
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -0,0 +1,125 @@
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'email',
+ 'email',
+ LOWER(BTRIM(u.email)),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'users.email',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(u.email, '')) <> ''
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'linuxdo',
+ 'linuxdo',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'wechat',
+ 'wechat',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+UPDATE users
+SET signup_source = 'linuxdo'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'wechat'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'oidc'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'oidc_synthetic_email_requires_manual_recovery',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities ai
+ WHERE ai.user_id = u.id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ )
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
new file mode 100644
index 00000000..f59b2188
--- /dev/null
+++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
@@ -0,0 +1,59 @@
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
+ granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT user_provider_default_grants_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
+ CONSTRAINT user_provider_default_grants_reason_check
+ CHECK (grant_reason IN ('signup', 'first_bind'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
+ ON user_provider_default_grants (user_id, provider_type, grant_reason);
+
+CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
+ ON user_provider_default_grants (user_id);
+
+CREATE TABLE IF NOT EXISTS user_avatars (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL DEFAULT '',
+ content_type VARCHAR(100) NOT NULL DEFAULT '',
+ byte_size INT NOT NULL DEFAULT 0,
+ sha256 VARCHAR(64) NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
+ ON user_avatars (user_id);
+
+INSERT INTO settings (key, value)
+VALUES
+ ('auth_source_default_email_balance', '0'),
+ ('auth_source_default_email_concurrency', '5'),
+ ('auth_source_default_email_subscriptions', '[]'),
+ ('auth_source_default_email_grant_on_signup', 'false'),
+ ('auth_source_default_email_grant_on_first_bind', 'false'),
+ ('auth_source_default_linuxdo_balance', '0'),
+ ('auth_source_default_linuxdo_concurrency', '5'),
+ ('auth_source_default_linuxdo_subscriptions', '[]'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'false'),
+ ('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
+ ('auth_source_default_oidc_balance', '0'),
+ ('auth_source_default_oidc_concurrency', '5'),
+ ('auth_source_default_oidc_subscriptions', '[]'),
+ ('auth_source_default_oidc_grant_on_signup', 'false'),
+ ('auth_source_default_oidc_grant_on_first_bind', 'false'),
+ ('auth_source_default_wechat_balance', '0'),
+ ('auth_source_default_wechat_concurrency', '5'),
+ ('auth_source_default_wechat_subscriptions', '[]'),
+ ('auth_source_default_wechat_grant_on_signup', 'false'),
+ ('auth_source_default_wechat_grant_on_first_bind', 'false'),
+ ('force_email_on_third_party_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
new file mode 100644
index 00000000..f222a8d4
--- /dev/null
+++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
@@ -0,0 +1,8 @@
+INSERT INTO settings (key, value)
+VALUES
+ ('payment_visible_method_alipay_source', ''),
+ ('payment_visible_method_wxpay_source', ''),
+ ('payment_visible_method_alipay_enabled', 'false'),
+ ('payment_visible_method_wxpay_enabled', 'false'),
+ ('openai_advanced_scheduler_enabled', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
new file mode 100644
index 00000000..d331b824
--- /dev/null
+++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
@@ -0,0 +1,10 @@
+ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30);
+
+UPDATE payment_orders
+SET provider_key = (
+ SELECT provider_key
+ FROM payment_provider_instances
+ WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id
+)
+WHERE provider_key IS NULL
+ AND provider_instance_id IS NOT NULL;
diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
new file mode 100644
index 00000000..15610af0
--- /dev/null
+++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
@@ -0,0 +1,89 @@
+UPDATE auth_identities AS ai
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+ );
+
+UPDATE auth_identity_channels AS channel
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identity_channels AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+ );
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_provider_key_conflict',
+ CAST(ai.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_identity_id', ai.id,
+ 'legacy_user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'canonical_identity_id', canon.id,
+ 'canonical_user_id', canon.user_id,
+ 'same_user', canon.user_id = ai.user_id,
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identities AS ai
+JOIN auth_identities AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_channel_provider_key_conflict',
+ CAST(channel.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_channel_id', channel.id,
+ 'legacy_identity_id', channel.identity_id,
+ 'canonical_channel_id', canon.id,
+ 'canonical_identity_id', canon.identity_id,
+ 'channel', channel.channel,
+ 'channel_app_id', channel.channel_app_id,
+ 'channel_subject', channel.channel_subject,
+ 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE),
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identity_channels AS channel
+JOIN auth_identity_channels AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+LEFT JOIN auth_identities AS legacy_identity
+ ON legacy_identity.id = channel.identity_id
+LEFT JOIN auth_identities AS canonical_identity
+ ON canonical_identity.id = canon.identity_id
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql
new file mode 100644
index 00000000..f84bf822
--- /dev/null
+++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql
@@ -0,0 +1,11 @@
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL;
+
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL;
+
+ALTER TABLE auth_identity_migration_reports
+ ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT '';
+
+CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at
+ ON auth_identity_migration_reports (resolved_at);
diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
new file mode 100644
index 00000000..264da3c9
--- /dev/null
+++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
@@ -0,0 +1,268 @@
+CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+),
+legacy_subjects AS (
+ SELECT
+ provider_user_id AS provider_subject,
+ COUNT(DISTINCT user_id) AS distinct_user_count
+ FROM legacy
+ GROUP BY provider_user_id
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_user_id
+ ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN legacy_subjects AS subjects
+ ON subjects.provider_subject = legacy.provider_user_id
+ AND subjects.distinct_user_count = 1
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'linuxdo',
+ 'linuxdo',
+ legacy.provider_user_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM canonical_legacy AS legacy
+WHERE legacy.canonical_row_num = 1
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(uei.provider_username) AS provider_username,
+ BTRIM(uei.display_name) AS display_name,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ uei.created_at,
+ uei.updated_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+),
+legacy_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject,
+ COUNT(DISTINCT user_id) AS distinct_user_count
+ FROM legacy
+ GROUP BY provider_union_id
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_union_id
+ ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN legacy_subjects AS subjects
+ ON subjects.provider_subject = legacy.provider_union_id
+ AND subjects.distinct_user_count = 1
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ 'wechat',
+ 'wechat-main',
+ legacy.provider_union_id,
+ COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', legacy.provider_union_id,
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM canonical_legacy AS legacy
+WHERE legacy.canonical_row_num = 1
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ BTRIM(uei.provider_union_id) AS provider_union_id,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
+ meta.metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ CROSS JOIN LATERAL (
+ SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ ) AS meta
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+),
+legacy_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject,
+ COUNT(DISTINCT user_id) AS distinct_user_count
+ FROM legacy
+ GROUP BY provider_union_id
+)
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM legacy
+JOIN legacy_subjects AS subjects
+ ON subjects.provider_subject = legacy.provider_union_id
+ AND subjects.distinct_user_count = 1
+JOIN auth_identities AS ai
+ ON ai.user_id = legacy.user_id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat-main'
+ AND ai.provider_subject = legacy.provider_union_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND legacy.provider_user_id <> ''
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(uei.provider_user_id) AS provider_user_id,
+ public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'synthetic_auth_identity:' || ai.id::text,
+ COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'auth_identity_id', ai.id,
+ 'user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'reason', 'synthetic wechat auth identity still lacks unionid metadata and needs remediation',
+ 'migration', '115_auth_identity_legacy_external_backfill'
+ )
+FROM auth_identities AS ai
+WHERE ai.provider_type = 'wechat'
+ AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'
+ AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT);
diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
new file mode 100644
index 00000000..81eb133c
--- /dev/null
+++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
@@ -0,0 +1,525 @@
+CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT)
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN '{}'::jsonb;
+ END IF;
+
+ BEGIN
+ parsed := input_text::jsonb;
+ EXCEPTION
+ WHEN OTHERS THEN
+ RETURN '{}'::jsonb;
+ END;
+
+ IF jsonb_typeof(parsed) = 'object' THEN
+ RETURN parsed;
+ END IF;
+
+ RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);
+END;
+$$;
+
+CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT)
+RETURNS BOOLEAN
+LANGUAGE plpgsql
+AS $$
+DECLARE
+ parsed JSONB;
+BEGIN
+ IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+ RETURN TRUE;
+ END IF;
+
+ parsed := input_text::jsonb;
+ RETURN TRUE;
+EXCEPTION
+ WHEN OTHERS THEN
+ RETURN FALSE;
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_invalid_metadata_json',
+ 'legacy_external_identity:' || uei.id::text,
+ jsonb_build_object(
+ 'legacy_identity_id', uei.id,
+ 'user_id', uei.user_id,
+ 'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))),
+ 'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')),
+ 'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')),
+ 'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object',
+ 'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000),
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM user_external_identities AS uei
+JOIN users AS u ON u.id = uei.user_id
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(uei.metadata, '')) <> ''
+ AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata)
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids,
+ 'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject,
+ to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids
+ FROM (
+ SELECT
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+ ) AS legacy_subjects
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) > 1
+) AS ambiguous
+ ON ambiguous.provider_type = legacy.provider_type
+ AND ambiguous.provider_key = legacy.provider_key
+ AND ambiguous.provider_subject = legacy.provider_subject
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_identity_id', ai.id,
+ 'existing_user_id', ai.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'reason', 'legacy canonical identity subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject
+ FROM (
+ SELECT
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+ ) AS legacy_subjects
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) = 1
+) AS clear_subjects
+ ON clear_subjects.provider_type = legacy.provider_type
+ AND clear_subjects.provider_key = legacy.provider_key
+ AND clear_subjects.provider_subject = legacy.provider_subject
+JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+),
+clear_subjects AS (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject
+ FROM legacy
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) = 1
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject
+ ORDER BY legacy.verified_at DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN clear_subjects
+ ON clear_subjects.provider_type = legacy.provider_type
+ AND clear_subjects.provider_key = legacy.provider_key
+ AND clear_subjects.provider_subject = legacy.provider_subject
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ legacy.provider_type,
+ legacy.provider_key,
+ legacy.provider_subject,
+ legacy.verified_at,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', NULLIF(legacy.provider_union_id, ''),
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM canonical_legacy AS legacy
+LEFT JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE legacy.canonical_row_num = 1
+ AND ai.id IS NULL
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_channel_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_channel_id', channel.id,
+ 'existing_identity_id', existing_ai.id,
+ 'existing_user_id', existing_ai.user_id,
+ 'provider_type', 'wechat',
+ 'provider_key', 'wechat-main',
+ 'provider_subject', legacy.provider_union_id,
+ 'channel', legacy.channel,
+ 'channel_app_id', legacy.channel_app_id,
+ 'channel_subject', legacy.provider_user_id,
+ 'reason', 'legacy channel subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN (
+ SELECT
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ GROUP BY BTRIM(COALESCE(uei.provider_union_id, ''))
+ HAVING COUNT(DISTINCT uei.user_id) = 1
+) AS clear_subjects
+ ON clear_subjects.provider_subject = legacy.provider_union_id
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+JOIN auth_identities AS existing_ai
+ ON existing_ai.id = channel.identity_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND existing_ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+),
+clear_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject
+ FROM legacy
+ GROUP BY provider_union_id
+ HAVING COUNT(DISTINCT user_id) = 1
+)
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ legacy_ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM legacy
+JOIN clear_subjects
+ ON clear_subjects.provider_subject = legacy.provider_union_id
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+LEFT JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND channel.id IS NULL
+ON CONFLICT DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'user_id', legacy.user_id,
+ 'openid', legacy.provider_user_id,
+ 'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identities_metadata_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identities
+ ADD CONSTRAINT auth_identities_metadata_is_object_check
+ CHECK (jsonb_typeof(metadata) = 'object');
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identity_channels_metadata_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identity_channels
+ ADD CONSTRAINT auth_identity_channels_metadata_is_object_check
+ CHECK (jsonb_typeof(metadata) = 'object');
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'auth_identity_migration_reports_details_is_object_check'
+ ) THEN
+ ALTER TABLE auth_identity_migration_reports
+ ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check
+ CHECK (jsonb_typeof(details) = 'object');
+ END IF;
+END $$;
+
+DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);
+DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);
diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql
new file mode 100644
index 00000000..56a5fe2d
--- /dev/null
+++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql
@@ -0,0 +1,2 @@
+ALTER TABLE payment_orders
+ADD COLUMN IF NOT EXISTS provider_snapshot JSONB;
diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql
new file mode 100644
index 00000000..18782617
--- /dev/null
+++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql
@@ -0,0 +1,25 @@
+INSERT INTO settings (key, value)
+VALUES
+ (
+ 'wechat_connect_open_enabled',
+ CASE
+ WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
+ WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
+ WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false'
+ ELSE 'true'
+ END
+ ),
+ (
+ 'wechat_connect_mp_enabled',
+ CASE
+ WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
+ WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
+ WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true'
+ ELSE 'false'
+ END
+ ),
+ ('auth_source_default_email_grant_on_signup', 'false'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'false'),
+ ('auth_source_default_oidc_grant_on_signup', 'false'),
+ ('auth_source_default_wechat_grant_on_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql
new file mode 100644
index 00000000..15e2c15f
--- /dev/null
+++ b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql
@@ -0,0 +1,6 @@
+-- Intentionally left as a no-op.
+-- The online index rollout lives in 120_enforce_payment_orders_out_trade_no_unique_notx.sql
+DO $$
+BEGIN
+ NULL;
+END $$;
diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql
new file mode 100644
index 00000000..638d8622
--- /dev/null
+++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql
@@ -0,0 +1,10 @@
+-- Build the payment order uniqueness guarantee online.
+-- The migration runner performs an explicit duplicate out_trade_no precheck and
+-- drops any stale invalid paymentorder_out_trade_no_unique index before retrying.
+-- Create the new partial unique index concurrently first so writes keep flowing,
+-- then remove the legacy index name once the replacement is ready.
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
diff --git a/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql
new file mode 100644
index 00000000..ef2599dc
--- /dev/null
+++ b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql
@@ -0,0 +1,22 @@
+DO $$
+BEGIN
+ IF EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = 'payment_orders'
+ AND indexname = 'paymentorder_out_trade_no_unique'
+ ) THEN
+ IF EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = 'payment_orders'
+ AND indexname = 'paymentorder_out_trade_no'
+ ) THEN
+ EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no';
+ END IF;
+
+ EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no';
+ END IF;
+END $$;
diff --git a/backend/migrations/121_auth_identity_migration_report_type_widen.sql b/backend/migrations/121_auth_identity_migration_report_type_widen.sql
new file mode 100644
index 00000000..66bfb44a
--- /dev/null
+++ b/backend/migrations/121_auth_identity_migration_report_type_widen.sql
@@ -0,0 +1,2 @@
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(80);
diff --git a/backend/migrations/122_pending_auth_completion_token_cleanup.sql b/backend/migrations/122_pending_auth_completion_token_cleanup.sql
new file mode 100644
index 00000000..e6341142
--- /dev/null
+++ b/backend/migrations/122_pending_auth_completion_token_cleanup.sql
@@ -0,0 +1,15 @@
+UPDATE pending_auth_sessions
+SET
+ local_flow_state = jsonb_set(
+ local_flow_state,
+ '{completion_response}',
+ ((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'),
+ true
+ )
+WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object'
+ AND (
+ (local_flow_state -> 'completion_response') ? 'access_token'
+ OR (local_flow_state -> 'completion_response') ? 'refresh_token'
+ OR (local_flow_state -> 'completion_response') ? 'expires_in'
+ OR (local_flow_state -> 'completion_response') ? 'token_type'
+ );
diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql
new file mode 100644
index 00000000..4388285a
--- /dev/null
+++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql
@@ -0,0 +1,68 @@
+-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value.
+-- Rows still matching the migration-110 default payload and timestamp window are treated as
+-- untouched legacy defaults; any remaining legacy true values are reported for manual review.
+
+WITH migration_110 AS (
+ SELECT applied_at
+ FROM schema_migrations
+ WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
+),
+providers AS (
+ SELECT provider_type
+ FROM (
+ VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
+ ) AS providers(provider_type)
+),
+legacy_provider_defaults AS (
+ SELECT providers.provider_type
+ FROM providers
+ CROSS JOIN migration_110
+ JOIN settings balance
+ ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
+ JOIN settings concurrency
+ ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
+ JOIN settings subscriptions
+ ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
+ JOIN settings grant_on_signup
+ ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
+ JOIN settings grant_on_first_bind
+ ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
+ WHERE balance.value = '0'
+ AND concurrency.value = '5'
+ AND subscriptions.value = '[]'
+ AND grant_on_signup.value = 'true'
+ AND grant_on_first_bind.value = 'false'
+ AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+),
+updated_signup_grants AS (
+ UPDATE settings
+ SET
+ value = 'false',
+ updated_at = NOW()
+ FROM legacy_provider_defaults
+ WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
+ AND settings.value = 'true'
+ RETURNING legacy_provider_defaults.provider_type
+)
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_auth_source_signup_grant_review',
+ providers.provider_type,
+ jsonb_build_object(
+ 'provider_type', providers.provider_type,
+ 'current_value', grant_on_signup.value,
+ 'auto_backfilled', FALSE,
+ 'reason', 'legacy_true_default_not_auto_backfilled'
+ )
+FROM providers
+JOIN settings grant_on_signup
+ ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
+LEFT JOIN updated_signup_grants
+ ON updated_signup_grants.provider_type = providers.provider_type
+WHERE grant_on_signup.value = 'true'
+ AND updated_signup_grants.provider_type IS NULL
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/124_backfill_legacy_oidc_security_flags.sql b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql
new file mode 100644
index 00000000..e68bb11a
--- /dev/null
+++ b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql
@@ -0,0 +1,32 @@
+-- Preserve legacy OIDC behavior for upgraded installs that predate the
+-- introduction of secure PKCE/id_token defaults. Fresh installs continue to
+-- inherit runtime defaults when these rows are absent.
+
+WITH legacy_oidc_install AS (
+ SELECT 1
+ FROM settings
+ WHERE key IN (
+ 'oidc_connect_enabled',
+ 'oidc_connect_client_id',
+ 'oidc_connect_authorize_url',
+ 'oidc_connect_token_url',
+ 'oidc_connect_issuer_url',
+ 'oidc_connect_userinfo_url',
+ 'oidc_connect_frontend_redirect_url'
+ )
+ LIMIT 1
+)
+INSERT INTO settings (key, value)
+SELECT defaults.key, 'false'
+FROM legacy_oidc_install
+CROSS JOIN (
+ VALUES
+ ('oidc_connect_use_pkce'),
+ ('oidc_connect_validate_id_token')
+) AS defaults(key)
+WHERE NOT EXISTS (
+ SELECT 1
+ FROM settings existing
+ WHERE existing.key = defaults.key
+)
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/125_add_channel_monitors.sql b/backend/migrations/125_add_channel_monitors.sql
new file mode 100644
index 00000000..5ec327da
--- /dev/null
+++ b/backend/migrations/125_add_channel_monitors.sql
@@ -0,0 +1,58 @@
+-- Migration: 125_add_channel_monitors
+-- 渠道监控 MVP:周期性对外部 provider/endpoint/api_key 做模型心跳测试。
+--
+-- 表结构说明:
+-- - channel_monitors 渠道配置表(一行 = 一个监控对象)
+-- - channel_monitor_histories 检测历史明细表(一次检测一个模型 = 一行)
+--
+-- 设计要点:
+-- - api_key_encrypted 列存放 AES-256-GCM 密文(base64),由 service 层加密。
+-- - extra_models 用 JSONB 存储字符串数组,便于扩展(后续可加权重等元数据)。
+-- - history 表通过 ON DELETE CASCADE 自动清理已删除监控的历史。
+-- - (enabled, last_checked_at) 索引服务于调度器扫描“到期需要检测”的监控。
+-- - histories 上 (monitor_id, model, checked_at DESC) 服务用户视图聚合查询;
+-- 单独的 (checked_at) 索引服务定期清理 30 天前数据的 DELETE。
+
+CREATE TABLE IF NOT EXISTS channel_monitors (
+ id BIGSERIAL PRIMARY KEY,
+ name VARCHAR(100) NOT NULL,
+ provider VARCHAR(20) NOT NULL, -- openai / anthropic / gemini
+ endpoint VARCHAR(500) NOT NULL, -- base origin
+ api_key_encrypted TEXT NOT NULL, -- AES-256-GCM (base64)
+ primary_model VARCHAR(200) NOT NULL,
+ extra_models JSONB NOT NULL DEFAULT '[]'::jsonb,
+ group_name VARCHAR(100) NOT NULL DEFAULT '',
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ interval_seconds INT NOT NULL,
+ last_checked_at TIMESTAMPTZ,
+ created_by BIGINT NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitors_provider_check CHECK (provider IN ('openai', 'anthropic', 'gemini')),
+ CONSTRAINT channel_monitors_interval_check CHECK (interval_seconds BETWEEN 15 AND 3600)
+);
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_enabled_last_checked
+ ON channel_monitors (enabled, last_checked_at);
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_provider
+ ON channel_monitors (provider);
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_group_name
+ ON channel_monitors (group_name);
+
+CREATE TABLE IF NOT EXISTS channel_monitor_histories (
+ id BIGSERIAL PRIMARY KEY,
+ monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE,
+ model VARCHAR(200) NOT NULL,
+ status VARCHAR(20) NOT NULL,
+ latency_ms INT,
+ ping_latency_ms INT,
+ message VARCHAR(500) NOT NULL DEFAULT '',
+ checked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitor_histories_status_check
+ CHECK (status IN ('operational', 'degraded', 'failed', 'error'))
+);
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_monitor_model_checked
+ ON channel_monitor_histories (monitor_id, model, checked_at DESC);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_checked_at
+ ON channel_monitor_histories (checked_at);
diff --git a/backend/migrations/125_add_group_rpm_limit.sql b/backend/migrations/125_add_group_rpm_limit.sql
new file mode 100644
index 00000000..fbde1b20
--- /dev/null
+++ b/backend/migrations/125_add_group_rpm_limit.sql
@@ -0,0 +1,7 @@
+-- Add per-group Requests-Per-Minute limit.
+-- rpm_limit: 分组统一 RPM 上限(0 = 不限制)。
+-- 一旦配置即接管该用户在该分组的限流,覆盖用户级 users.rpm_limit。
+-- 计数键:rpm:ug:{user_id}:{group_id}:{minute}。
+ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。';
diff --git a/backend/migrations/126_add_channel_monitor_aggregation.sql b/backend/migrations/126_add_channel_monitor_aggregation.sql
new file mode 100644
index 00000000..e643763c
--- /dev/null
+++ b/backend/migrations/126_add_channel_monitor_aggregation.sql
@@ -0,0 +1,60 @@
+-- Migration: 126_add_channel_monitor_aggregation
+-- 渠道监控日聚合:把 channel_monitor_histories 的明细按天聚合,明细只保留 1 天,
+-- 聚合保留 30 天。明细和聚合表都用软删除(deleted_at),由 ops cleanup 任务每天
+-- 凌晨随运维监控清理一起跑(共享 cron)。
+--
+-- 设计要点:
+-- - channel_monitor_histories 加 deleted_at 软删除字段(SoftDeleteMixin 全局
+-- Hook 会把 DELETE 自动改写成 UPDATE deleted_at = NOW())。
+-- - channel_monitor_daily_rollups 按 (monitor_id, model, bucket_date) 唯一,
+-- 用 ON CONFLICT DO UPDATE 实现幂等回填,状态分布和延迟分子分母都保留,
+-- 方便后续按窗口任意求加权可用率和均值。
+-- - watermark 表只有一行(id=1),记录最近一次聚合到达的日期,避免重启后重复
+-- 扫全表。
+-- - rollup 上 (bucket_date) 索引服务清理任务的 DELETE WHERE bucket_date < cutoff。
+
+-- 1) 给历史明细表加软删除字段
+ALTER TABLE channel_monitor_histories
+ ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ;
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_deleted_at
+ ON channel_monitor_histories (deleted_at);
+
+-- 2) 创建日聚合表
+CREATE TABLE IF NOT EXISTS channel_monitor_daily_rollups (
+ id BIGSERIAL PRIMARY KEY,
+ monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE,
+ model VARCHAR(200) NOT NULL,
+ bucket_date DATE NOT NULL,
+ total_checks INT NOT NULL DEFAULT 0,
+ ok_count INT NOT NULL DEFAULT 0,
+ operational_count INT NOT NULL DEFAULT 0,
+ degraded_count INT NOT NULL DEFAULT 0,
+ failed_count INT NOT NULL DEFAULT 0,
+ error_count INT NOT NULL DEFAULT 0,
+ sum_latency_ms BIGINT NOT NULL DEFAULT 0,
+ count_latency INT NOT NULL DEFAULT 0,
+ sum_ping_latency_ms BIGINT NOT NULL DEFAULT 0,
+ count_ping_latency INT NOT NULL DEFAULT 0,
+ computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ deleted_at TIMESTAMPTZ
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_unique
+ ON channel_monitor_daily_rollups (monitor_id, model, bucket_date);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_bucket
+ ON channel_monitor_daily_rollups (bucket_date);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_deleted_at
+ ON channel_monitor_daily_rollups (deleted_at);
+
+-- 3) 创建 watermark 表(单行:id=1)
+CREATE TABLE IF NOT EXISTS channel_monitor_aggregation_watermark (
+ id INT PRIMARY KEY DEFAULT 1,
+ last_aggregated_date DATE,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitor_aggregation_watermark_singleton CHECK (id = 1)
+);
+
+INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
+VALUES (1, NULL, NOW())
+ON CONFLICT (id) DO NOTHING;
diff --git a/backend/migrations/126_add_user_rpm_limit.sql b/backend/migrations/126_add_user_rpm_limit.sql
new file mode 100644
index 00000000..64a8b977
--- /dev/null
+++ b/backend/migrations/126_add_user_rpm_limit.sql
@@ -0,0 +1,7 @@
+-- Add per-user Requests-Per-Minute cap.
+-- rpm_limit: 用户全局 RPM 兜底(0 = 不限制)。
+-- 仅当所访问分组未设置 rpm_limit 且无 user-group rpm_override 时作为兜底生效。
+-- 计数键:rpm:u:{user_id}:{minute}。
+ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。';
diff --git a/backend/migrations/127_add_user_group_rpm_override.sql b/backend/migrations/127_add_user_group_rpm_override.sql
new file mode 100644
index 00000000..1d674258
--- /dev/null
+++ b/backend/migrations/127_add_user_group_rpm_override.sql
@@ -0,0 +1,16 @@
+-- 在已有的"用户专属分组倍率表"上扩展 rpm_override 列;同时放宽 rate_multiplier 为可空,
+-- 使一行记录可以只覆盖 rate、只覆盖 rpm,或同时覆盖两者。
+-- 语义:
+-- - rate_multiplier NULL → 该用户在此分组使用 groups.rate_multiplier 默认值
+-- - rate_multiplier 非 NULL → 覆盖分组默认计费倍率
+-- - rpm_override NULL → 该用户在此分组使用 groups.rpm_limit 默认值
+-- - rpm_override 非 NULL → 覆盖分组默认 RPM(0 = 不限制)
+-- 用户级 users.rpm_limit 仍独立生效(跨分组总配额)。
+ALTER TABLE user_group_rate_multipliers
+ ADD COLUMN IF NOT EXISTS rpm_override integer NULL;
+
+ALTER TABLE user_group_rate_multipliers
+ ALTER COLUMN rate_multiplier DROP NOT NULL;
+
+COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。';
+COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。';
diff --git a/backend/migrations/127_drop_channel_monitor_deleted_at.sql b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
new file mode 100644
index 00000000..2260f06b
--- /dev/null
+++ b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
@@ -0,0 +1,16 @@
+-- Migration: 127_drop_channel_monitor_deleted_at
+-- 纠正 110 引入的 SoftDeleteMixin:日志/聚合表无恢复需求,软删会让行和索引只增不减,
+-- 徒增磁盘和查询开销。改回分批物理删(由 OpsCleanupService 每天凌晨统一调度,
+-- deleteOldRowsByID 模板,batch=5000)。
+--
+-- 110 尚未跑过聚合/清理(首次 maintenance 在次日 02:00),所以此处不担心业务数据。
+-- 直接 DROP 列 + 索引;对应的 Go 侧 ent schema 已移除 SoftDeleteMixin、repo 的
+-- raw SQL 已移除 deleted_at IS NULL 过滤。
+
+DROP INDEX IF EXISTS idx_channel_monitor_histories_deleted_at;
+ALTER TABLE channel_monitor_histories
+ DROP COLUMN IF EXISTS deleted_at;
+
+DROP INDEX IF EXISTS idx_channel_monitor_daily_rollups_deleted_at;
+ALTER TABLE channel_monitor_daily_rollups
+ DROP COLUMN IF EXISTS deleted_at;
diff --git a/backend/migrations/128_add_channel_monitor_request_templates.sql b/backend/migrations/128_add_channel_monitor_request_templates.sql
new file mode 100644
index 00000000..2db8fef6
--- /dev/null
+++ b/backend/migrations/128_add_channel_monitor_request_templates.sql
@@ -0,0 +1,70 @@
+-- Migration: 128_add_channel_monitor_request_templates
+-- 加请求模板表 + 给 channel_monitors 加 4 个快照字段(template_id 关联引用 + extra_headers /
+-- body_override_mode / body_override 三个真正运行时使用的快照)。
+--
+-- 设计要点:
+-- 1) 模板与监控之间是「应用即拷贝」的快照语义,运行时 checker 不再回查模板表。
+-- 模板 UPDATE 不会自动影响监控;只有用户主动「应用到关联监控」才会刷新快照。
+-- 2) ON DELETE SET NULL:模板删除不级联清理监控;监控保留快照继续工作。
+-- 3) extra_headers / body_override 都是 JSONB;body_override_mode 用 varchar(不是 enum)
+-- 便于将来加新模式无需 ALTER TYPE。
+-- 4) 同一 provider 内模板 name 唯一(允许 Anthropic + OpenAI 重名 "伪装官方客户端")。
+
+CREATE TABLE IF NOT EXISTS channel_monitor_request_templates (
+ id BIGSERIAL PRIMARY KEY,
+ name VARCHAR(100) NOT NULL,
+ provider VARCHAR(20) NOT NULL,
+ description VARCHAR(500) NOT NULL DEFAULT '',
+ extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb,
+ body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off',
+ body_override JSONB NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitor_request_templates_provider_check
+ CHECK (provider IN ('openai', 'anthropic', 'gemini')),
+ CONSTRAINT channel_monitor_request_templates_body_mode_check
+ CHECK (body_override_mode IN ('off', 'merge', 'replace'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS channel_monitor_request_templates_provider_name
+ ON channel_monitor_request_templates (provider, name);
+
+-- channel_monitors 加 4 列(ADD COLUMN IF NOT EXISTS 需要 PG 9.6+,生产使用 PG 16)
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS template_id BIGINT NULL;
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb;
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off';
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS body_override JSONB NULL;
+
+-- 约束 + 外键(DO 块里 IF NOT EXISTS 判断,保证幂等)
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM information_schema.table_constraints
+ WHERE constraint_name = 'channel_monitors_body_mode_check'
+ AND table_name = 'channel_monitors'
+ ) THEN
+ ALTER TABLE channel_monitors
+ ADD CONSTRAINT channel_monitors_body_mode_check
+ CHECK (body_override_mode IN ('off', 'merge', 'replace'));
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1 FROM information_schema.table_constraints
+ WHERE constraint_name = 'channel_monitors_template_id_fkey'
+ AND table_name = 'channel_monitors'
+ ) THEN
+ ALTER TABLE channel_monitors
+ ADD CONSTRAINT channel_monitors_template_id_fkey
+ FOREIGN KEY (template_id)
+ REFERENCES channel_monitor_request_templates (id)
+ ON DELETE SET NULL;
+ END IF;
+END $$;
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_template_id
+ ON channel_monitors (template_id)
+ WHERE template_id IS NOT NULL;
diff --git a/backend/migrations/129_seed_claude_code_template.sql b/backend/migrations/129_seed_claude_code_template.sql
new file mode 100644
index 00000000..d9b062c9
--- /dev/null
+++ b/backend/migrations/129_seed_claude_code_template.sql
@@ -0,0 +1,38 @@
+-- Migration: 129_seed_claude_code_template
+-- 内置「Claude Code 伪装」请求模板,覆盖 Anthropic 上游对官方 CLI 客户端的所有验证项:
+-- 1) User-Agent / X-App / anthropic-beta / anthropic-version 等头
+-- 2) system 数组首项与官方 system prompt 字面一致(Dice >= 0.5)
+-- 3) metadata.user_id 满足 ParseMetadataUserID — 这里用 legacy 格式(user_<64hex>_account__session_<36char>)
+-- 避免新版 JSON 字符串内嵌 JSON 在编辑器里出现一长串 \" 转义,便于用户阅读。
+--
+-- ON CONFLICT DO NOTHING:已部署环境(手动建过模板)跑此 migration 不会重复 / 覆盖。
+-- 用户可自行编辑后续覆盖此 seed;CC 升大版时再起一条 migration 提供新模板,不动用户的旧模板。
+
+INSERT INTO channel_monitor_request_templates (
+ name, provider, description, extra_headers, body_override_mode, body_override
+)
+VALUES (
+ 'Claude Code 伪装',
+ 'anthropic',
+ '完整模拟 Claude Code 2.1.114 客户端:UA + anthropic-beta + system + metadata.user_id 全部对齐,绕过 Anthropic 上游 ''Claude Code only'' 限制(如 Max 套餐)。',
+ '{
+ "User-Agent": "claude-cli/2.1.114 (external, sdk-cli)",
+ "X-App": "cli",
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "claude-code-20250219,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,advisor-tool-2026-03-01",
+ "anthropic-dangerous-direct-browser-access": "true"
+ }'::jsonb,
+ 'merge',
+ '{
+ "system": [
+ {
+ "type": "text",
+ "text": "You are Claude Code, Anthropic''s official CLI for Claude."
+ }
+ ],
+ "metadata": {
+ "user_id": "user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"
+ }
+ }'::jsonb
+)
+ON CONFLICT (provider, name) DO NOTHING;
diff --git a/backend/migrations/130_add_user_affiliates.sql b/backend/migrations/130_add_user_affiliates.sql
new file mode 100644
index 00000000..d8c001e0
--- /dev/null
+++ b/backend/migrations/130_add_user_affiliates.sql
@@ -0,0 +1,20 @@
+CREATE TABLE IF NOT EXISTS user_affiliates (
+ user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
+ aff_code VARCHAR(32) NOT NULL UNIQUE,
+ inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ aff_count INTEGER NOT NULL DEFAULT 0,
+ aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
+ aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id);
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota);
+
+COMMENT ON TABLE user_affiliates IS '用户邀请返利信息';
+COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码';
+COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID';
+COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数';
+COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额';
+COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额';
diff --git a/backend/migrations/131_affiliate_rebate_hardening.sql b/backend/migrations/131_affiliate_rebate_hardening.sql
new file mode 100644
index 00000000..81e37a9e
--- /dev/null
+++ b/backend/migrations/131_affiliate_rebate_hardening.sql
@@ -0,0 +1,58 @@
+-- 1) Normalize historical affiliate rebate rate values.
+-- Legacy compatibility treated 0 20%).
+-- We now use pure percentage semantics, so convert persisted fractional values once.
+UPDATE settings
+SET value = to_char((value::numeric * 100), 'FM999999990.########'),
+ updated_at = NOW()
+WHERE key = 'affiliate_rebate_rate'
+ AND value ~ '^-?[0-9]+(\\.[0-9]+)?$'
+ AND value::numeric > 0
+ AND value::numeric <= 1;
+
+-- 2) Affiliate ledger for accrual/transfer traceability.
+CREATE TABLE IF NOT EXISTS user_affiliate_ledger (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ action VARCHAR(32) NOT NULL,
+ amount DECIMAL(20,8) NOT NULL,
+ source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action);
+
+COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)';
+COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer';
+
+-- 3) Enforce idempotency at DB layer for payment audit actions.
+WITH ranked AS (
+ SELECT id,
+ ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn
+ FROM payment_audit_logs
+)
+DELETE FROM payment_audit_logs p
+USING ranked r
+WHERE p.id = r.id
+ AND r.rn > 1;
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq
+ON payment_audit_logs(order_id, action);
+
+-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders.
+INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
+SELECT po.id::text,
+ 'AFFILIATE_REBATE_SKIPPED',
+ '{"reason":"baseline before affiliate rebate idempotency rollout"}',
+ 'system',
+ NOW()
+FROM payment_orders po
+WHERE po.order_type = 'balance'
+ AND po.status = 'COMPLETED'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM payment_audit_logs pal
+ WHERE pal.order_id = po.id::text
+ AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
+ );
diff --git a/backend/migrations/132_affiliate_custom_settings.sql b/backend/migrations/132_affiliate_custom_settings.sql
new file mode 100644
index 00000000..840fe8e0
--- /dev/null
+++ b/backend/migrations/132_affiliate_custom_settings.sql
@@ -0,0 +1,16 @@
+-- 邀请返利:用户专属配置增强
+-- 1) aff_rebate_rate_percent: 用户作为邀请人时的专属返利比例(百分比,NULL 表示沿用全局比例)
+-- 2) aff_code_custom: 标记当前 aff_code 是否被管理员手动改写过(用于"专属用户"列表筛选)
+
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_rebate_rate_percent DECIMAL(5,2);
+
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_code_custom BOOLEAN NOT NULL DEFAULT false;
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_admin_settings
+ ON user_affiliates (updated_at)
+ WHERE aff_code_custom = true OR aff_rebate_rate_percent IS NOT NULL;
+
+COMMENT ON COLUMN user_affiliates.aff_rebate_rate_percent IS '专属返利比例(百分比 0-100,NULL 表示沿用全局)';
+COMMENT ON COLUMN user_affiliates.aff_code_custom IS '邀请码是否由管理员改写过(用于专属用户筛选)';
diff --git a/backend/migrations/133_affiliate_rebate_freeze.sql b/backend/migrations/133_affiliate_rebate_freeze.sql
new file mode 100644
index 00000000..b87d59b7
--- /dev/null
+++ b/backend/migrations/133_affiliate_rebate_freeze.sql
@@ -0,0 +1,17 @@
+-- 1) Add frozen quota column to user_affiliates for rebate freeze period.
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)';
+
+-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking.
+-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp.
+ALTER TABLE user_affiliate_ledger
+ ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL;
+
+COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen';
+
+-- 3) Partial index for efficient thaw queries (only rows still frozen).
+CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw
+ ON user_affiliate_ledger (user_id, frozen_until)
+ WHERE frozen_until IS NOT NULL;
diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go
new file mode 100644
index 00000000..798ae0fe
--- /dev/null
+++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go
@@ -0,0 +1,129 @@
+package migrations
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestMigration112UsesIdempotentAddColumn(t *testing.T) {
+ content, err := FS.ReadFile("112_add_payment_order_provider_key_snapshot.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30)")
+ require.NotContains(t, sql, "ADD COLUMN provider_key VARCHAR(30);")
+}
+
+func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) {
+ content, err := FS.ReadFile("118_wechat_dual_mode_and_auth_source_defaults.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.NotContains(t, sql, "UPDATE settings")
+ require.NotContains(t, sql, "SET value = 'false'")
+ require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING"))
+ require.Contains(t, sql, "THEN ''")
+}
+
+func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) {
+ preflightContent, err := FS.ReadFile("108a_widen_auth_identity_migration_report_type.sql")
+ require.NoError(t, err)
+
+ preflightSQL := string(preflightContent)
+ require.Contains(t, preflightSQL, "ALTER TABLE auth_identity_migration_reports")
+ require.Contains(t, preflightSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
+
+ content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.NotContains(t, sql, "ALTER TABLE auth_identity_migration_reports")
+
+ followupContent, err := FS.ReadFile("121_auth_identity_migration_report_type_widen.sql")
+ require.NoError(t, err)
+
+ followupSQL := string(followupContent)
+ require.Contains(t, followupSQL, "ALTER TABLE auth_identity_migration_reports")
+ require.Contains(t, followupSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
+}
+
+func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) {
+ content, err := FS.ReadFile("119_enforce_payment_orders_out_trade_no_unique.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "120_enforce_payment_orders_out_trade_no_unique_notx.sql")
+ require.Contains(t, sql, "NULL;")
+ require.NotContains(t, sql, "CREATE UNIQUE INDEX")
+ require.NotContains(t, sql, "DROP INDEX")
+
+ followupContent, err := FS.ReadFile("120_enforce_payment_orders_out_trade_no_unique_notx.sql")
+ require.NoError(t, err)
+
+ followupSQL := string(followupContent)
+ require.Contains(t, followupSQL, "explicit duplicate out_trade_no precheck")
+ require.Contains(t, followupSQL, "stale invalid paymentorder_out_trade_no_unique index")
+ require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique")
+ require.NotContains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique")
+ require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
+ require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
+
+ alignmentContent, err := FS.ReadFile("120a_align_payment_orders_out_trade_no_index_name.sql")
+ require.NoError(t, err)
+
+ alignmentSQL := string(alignmentContent)
+ require.Contains(t, alignmentSQL, "paymentorder_out_trade_no_unique")
+ require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no")
+}
+
+func TestMigration110SeedsAuthSourceSignupGrantsDisabledByDefault(t *testing.T) {
+ content, err := FS.ReadFile("110_pending_auth_and_provider_default_grants.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "('auth_source_default_email_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_linuxdo_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_oidc_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_wechat_grant_on_signup', 'false')")
+ require.NotContains(t, sql, "('auth_source_default_email_grant_on_signup', 'true')")
+}
+
+func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) {
+ content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "UPDATE pending_auth_sessions")
+ require.Contains(t, sql, "completion_response")
+ require.Contains(t, sql, "access_token")
+ require.Contains(t, sql, "refresh_token")
+ require.Contains(t, sql, "expires_in")
+ require.Contains(t, sql, "token_type")
+}
+
+func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) {
+ content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql")
+ require.Contains(t, sql, "schema_migrations")
+ require.Contains(t, sql, "updated_at")
+ require.Contains(t, sql, "'_grant_on_signup'")
+ require.Contains(t, sql, "value = 'false'")
+ require.Contains(t, sql, "auth_identity_migration_reports")
+}
+
+func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
+ content, err := FS.ReadFile("124_backfill_legacy_oidc_security_flags.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "oidc_connect_use_pkce")
+ require.Contains(t, sql, "oidc_connect_validate_id_token")
+ require.Contains(t, sql, "ON CONFLICT (key) DO NOTHING")
+ require.Contains(t, sql, "oidc_connect_enabled")
+ require.Contains(t, sql, "'false'")
+}
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 358f6a31..dfc363b5 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -841,7 +841,7 @@ linuxdo_connect:
frontend_redirect_url: "/auth/linuxdo/callback"
token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none
# 注意:当 token_auth_method=none(public client)时,必须启用 PKCE
- use_pkce: false
+ use_pkce: true
userinfo_email_path: ""
userinfo_id_path: ""
userinfo_username_path: ""
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
index 755b313a..af93fa7e 100644
--- a/docs/PAYMENT.md
+++ b/docs/PAYMENT.md
@@ -22,11 +22,11 @@ Sub2API has a built-in payment system that enables user self-service top-up with
| Provider | Payment Methods | Description |
|----------|----------------|-------------|
| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol |
-| **Alipay (Direct)** | PC Page Pay, H5 Mobile Pay | Direct integration with Alipay Open Platform, auto-switches by device |
-| **WeChat Pay (Direct)** | Native QR Code, H5 Pay | Direct integration with WeChat Pay APIv3, mobile-first H5 |
+| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links |
+| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing |
| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support |
-> Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
+> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
>
@@ -61,9 +61,18 @@ Configure the following in Admin Dashboard **Settings → Payment Settings**:
| **Minimum Amount** | Minimum single top-up amount | 1 |
| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - |
| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - |
-| **Order Timeout** | Order timeout in minutes (minimum 1) | 5 |
+| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 |
| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
-| **Load Balance Strategy** | Strategy for selecting provider instances | Least Amount |
+| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin |
+
+### Frontend Visible Method Routing
+
+The current payment UX keeps the frontend method list unified and does not expose provider brands directly:
+
+- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay`
+- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat`
+- Each visible method can route to only one source at a time
+- If a visible method is enabled without a selected source, the frontend will not expose that method
### Load Balance Strategies
@@ -113,7 +122,7 @@ Compatible with any payment service that implements the EasyPay protocol.
### Alipay (Direct)
-Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile pay.
+Direct integration with Alipay Open Platform. Mobile flows return an Alipay WAP/app redirect URL. Desktop flows prefer Face-to-Face Precreate QR payloads; if the merchant has not enabled that product, the provider falls back to Computer Website Pay and also returns the cashier URL so the frontend can render a QR code or open the hosted checkout page directly.
| Parameter | Description | Required |
|-----------|-------------|----------|
@@ -123,7 +132,7 @@ Direct integration with Alipay Open Platform. Supports PC page pay and H5 mobile
### WeChat Pay (Direct)
-Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment.
+Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment.
| Parameter | Description | Required |
|-----------|-------------|----------|
@@ -132,8 +141,8 @@ Direct integration with WeChat Pay APIv3. Supports Native QR code and H5 payment
| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
| **APIv3 Key** | 32-byte APIv3 key | Yes |
| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
-| **WeChat Pay Public Key ID** | WeChat Pay public key ID | No |
-| **Certificate Serial Number** | Merchant certificate serial number | No |
+| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
+| **Certificate Serial Number** | Merchant certificate serial number | Yes |
### Stripe
@@ -220,8 +229,8 @@ User selects amount and payment method
▼
User completes payment
├─ EasyPay → QR code / H5 redirect
- ├─ Alipay → PC page pay / H5 mobile pay
- ├─ WeChat Pay → Native QR / H5 pay
+ ├─ Alipay → Desktop QR payload (Face-to-Face preferred, Website Pay fallback) / mobile Alipay redirect
+ ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI
└─ Stripe → Payment Element (card/Alipay/WeChat/etc.)
│
▼
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
index aca3c866..ae765fb9 100644
--- a/docs/PAYMENT_CN.md
+++ b/docs/PAYMENT_CN.md
@@ -22,11 +22,11 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| 服务商 | 支付方式 | 说明 |
|--------|---------|------|
| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 |
-| **支付宝官方** | 支付宝 PC 页面支付、H5 手机网站支付 | 直接对接支付宝开放平台,自动根据终端切换 |
-| **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 |
+| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 |
+| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 |
| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
-> 支付宝官方 / 微信官方与易支付可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
+> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
>
@@ -61,9 +61,18 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **最低金额** | 单笔最低充值金额 | 1 |
| **最高金额** | 单笔最高充值金额(留空表示不限制) | - |
| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - |
-| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 5 |
+| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 |
| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
-| **负载均衡策略** | 多服务商实例时的选择策略 | 最少金额 |
+| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 |
+
+### 前台可见支付方式路由
+
+当前版本对用户统一展示支付方式,不区分官方渠道还是易支付:
+
+- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝`
+- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信`
+- 同一个可见支付方式在同一时刻只能路由到一个来源
+- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式
### 负载均衡策略
@@ -113,7 +122,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
### 支付宝官方
-直接对接支付宝开放平台,支持 PC 页面支付和 H5 手机网站支付。
+直接对接支付宝开放平台。移动端走支付宝手机网站支付跳转;桌面端优先使用当面付返回扫码串,若商户未开通当面付则回退到电脑网站支付,并将收银台链接同时返回给前端用于渲染二维码或直接打开支付页。
| 参数 | 说明 | 必填 |
|------|------|------|
@@ -123,7 +132,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
### 微信官方
-直接对接微信支付 APIv3,支持 Native 扫码支付和 H5 支付。
+直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。
| 参数 | 说明 | 必填 |
|------|------|------|
@@ -132,8 +141,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
-| **微信支付公钥 ID** | 微信支付公钥 ID | 否 |
-| **商户证书序列号** | 商户证书序列号 | 否 |
+| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
+| **商户证书序列号** | 商户证书序列号 | 是 |
### Stripe
@@ -220,8 +229,8 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支
▼
用户完成支付
├─ EasyPay → 扫码 / H5 跳转
- ├─ 支付宝官方 → PC 页面支付 / H5 手机网站支付
- ├─ 微信官方 → Native 扫码 / H5 支付
+ ├─ 支付宝官方 → 桌面扫码单(当面付优先,电脑网站支付回退)/ 移动端支付宝跳转
+ ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI
└─ Stripe → Payment Element(银行卡/支付宝/微信等)
│
▼
diff --git a/frontend/src/api/__tests__/admin.users.spec.ts b/frontend/src/api/__tests__/admin.users.spec.ts
new file mode 100644
index 00000000..37656b78
--- /dev/null
+++ b/frontend/src/api/__tests__/admin.users.spec.ts
@@ -0,0 +1,117 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { post } = vi.hoisted(() => ({
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post,
+ },
+}))
+
+import {
+ bindUserAuthIdentity,
+ type AdminBindAuthIdentityRequest,
+ type AdminBoundAuthIdentity,
+} from '@/api/admin/users'
+
+type Assert = T
+type IsExact = (
+ (() => G extends T ? 1 : 2) extends (() => G extends U ? 1 : 2)
+ ? ((() => G extends U ? 1 : 2) extends (() => G extends T ? 1 : 2) ? true : false)
+ : false
+)
+
+type ExpectedAdminBindAuthIdentityRequest = {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string
+ metadata?: Record
+ channel?: {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata?: Record
+ }
+}
+
+type ExpectedAdminBoundAuthIdentity = {
+ user_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ verified_at?: string | null
+ issuer?: string | null
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ channel?: {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ } | null
+}
+
+const requestContractExact: Assert<
+ IsExact
+> = true
+const responseContractExact: Assert<
+ IsExact
+> = true
+
+describe('admin users api auth identity binding', () => {
+ beforeEach(() => {
+ post.mockReset()
+ })
+
+ it('posts the backend-compatible auth identity bind payload and returns the backend response shape', async () => {
+ const payload: AdminBindAuthIdentityRequest = {
+ provider_type: 'wechat',
+ provider_key: 'wechat-main',
+ provider_subject: 'union-123',
+ metadata: { source: 'admin-repair' },
+ channel: {
+ channel: 'open',
+ channel_app_id: 'wx-open',
+ channel_subject: 'openid-123',
+ metadata: { scene: 'migration' },
+ },
+ }
+
+ const response: AdminBoundAuthIdentity = {
+ user_id: 9,
+ provider_type: 'wechat',
+ provider_key: 'wechat-main',
+ provider_subject: 'union-123',
+ verified_at: '2026-04-22T00:00:00Z',
+ issuer: null,
+ metadata: { source: 'admin-repair' },
+ created_at: '2026-04-22T00:00:00Z',
+ updated_at: '2026-04-22T00:00:00Z',
+ channel: {
+ channel: 'open',
+ channel_app_id: 'wx-open',
+ channel_subject: 'openid-123',
+ metadata: { scene: 'migration' },
+ created_at: '2026-04-22T00:00:00Z',
+ updated_at: '2026-04-22T00:00:00Z',
+ },
+ }
+ post.mockResolvedValue({ data: response })
+
+ const result = await bindUserAuthIdentity(9, payload)
+
+ expect(post).toHaveBeenCalledWith('/admin/users/9/auth-identities', payload)
+ expect(result).toEqual(response)
+ })
+
+ it('keeps bind auth identity request and response types aligned with the backend contract', () => {
+ expect(requestContractExact).toBe(true)
+ expect(responseContractExact).toBe(true)
+ })
+})
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
new file mode 100644
index 00000000..07a68c03
--- /dev/null
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -0,0 +1,224 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const post = vi.fn()
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post
+ }
+}))
+
+describe('oauth adoption auth api', () => {
+ beforeEach(() => {
+ post.mockReset()
+ post.mockResolvedValue({ data: {} })
+ localStorage.clear()
+ document.cookie = 'oauth_bind_access_token=; Max-Age=0; path=/'
+ })
+
+ it('posts adoption decisions when exchanging pending oauth completion', async () => {
+ const { exchangePendingOAuthCompletion } = await import('@/api/auth')
+
+ await exchangePendingOAuthCompletion({
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts bind-login decisions when finalizing pending oauth bind flow', async () => {
+ const { completePendingOAuthBindLogin } = await import('@/api/auth')
+
+ await completePendingOAuthBindLogin({
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts linuxdo invitation completion with adoption decisions', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts linuxdo create-account completion with adoption decisions', async () => {
+ const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth')
+
+ await createPendingLinuxDoOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts affiliate code when completing linuxdo oauth registration', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration(
+ 'invite-code',
+ {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ },
+ ' AFF123 '
+ )
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ aff_code: 'AFF123',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts oidc invitation completion with adoption decisions', async () => {
+ const { completeOIDCOAuthRegistration } = await import('@/api/auth')
+
+ await completeOIDCOAuthRegistration('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts oidc create-account completion with adoption decisions', async () => {
+ const { createPendingOIDCOAuthAccount } = await import('@/api/auth')
+
+ await createPendingOIDCOAuthAccount('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts wechat invitation completion with adoption decisions', async () => {
+ const { completeWeChatOAuthRegistration } = await import('@/api/auth')
+
+ await completeWeChatOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts wechat create-account completion with adoption decisions', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts affiliate code when creating pending wechat oauth account', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount(
+ 'invite-code',
+ {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ },
+ 'WXAFF'
+ )
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ aff_code: 'WXAFF',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('classifies oauth completion results as login or bind', async () => {
+ const { getOAuthCompletionKind } = await import('@/api/auth')
+
+ expect(getOAuthCompletionKind({ access_token: 'access-token' })).toBe('login')
+ expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind')
+ })
+
+ it('provides bind-login utility helpers for invitation and suggested profile states', async () => {
+ const {
+ getPendingOAuthBindLoginKind,
+ hasPendingOAuthSuggestedProfile,
+ isPendingOAuthCreateAccountRequired
+ } = await import('@/api/auth')
+
+ expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login')
+ expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind')
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'invitation_required'
+ })
+ ).toBe(true)
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'other'
+ })
+ ).toBe(false)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_display_name: 'OAuth Nick'
+ })
+ ).toBe(true)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_avatar_url: 'https://cdn.example/avatar.png'
+ })
+ ).toBe(true)
+ expect(hasPendingOAuthSuggestedProfile({})).toBe(false)
+ })
+
+ it('requests an HttpOnly oauth bind cookie before redirect binding', async () => {
+ localStorage.setItem('auth_token', 'access-token-value')
+ const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth')
+
+ await prepareOAuthBindAccessTokenCookie()
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/bind-token')
+ })
+})
diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts
index 0f663e76..a46c39eb 100644
--- a/frontend/src/api/__tests__/client.spec.ts
+++ b/frontend/src/api/__tests__/client.spec.ts
@@ -91,6 +91,22 @@ describe('API Client', () => {
const config = adapter.mock.calls[0][0]
expect(config.params?.timezone).toBeUndefined()
})
+
+ it('请求默认带 withCredentials 以支持跨域 cookie', async () => {
+ const adapter = vi.fn().mockResolvedValue({
+ status: 200,
+ data: { code: 0, data: {} },
+ headers: {},
+ config: {},
+ statusText: 'OK',
+ })
+ apiClient.defaults.adapter = adapter
+
+ await apiClient.post('/auth/oauth/bind-token')
+
+ const config = adapter.mock.calls[0][0]
+ expect(config.withCredentials).toBe(true)
+ })
})
// --- 响应拦截器 ---
diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts
new file mode 100644
index 00000000..e38fba57
--- /dev/null
+++ b/frontend/src/api/__tests__/payment.spec.ts
@@ -0,0 +1,40 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { get, post } = vi.hoisted(() => ({
+ get: vi.fn(),
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ get,
+ post,
+ },
+}))
+
+import { paymentAPI } from '@/api/payment'
+
+describe('payment api', () => {
+ beforeEach(() => {
+ get.mockReset()
+ post.mockReset()
+ get.mockResolvedValue({ data: {} })
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('keeps legacy public out_trade_no verification for upgrade compatibility', async () => {
+ await paymentAPI.verifyOrderPublic('legacy-order-no')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/verify', {
+ out_trade_no: 'legacy-order-no',
+ })
+ })
+
+ it('keeps signed public resume-token resolve endpoint', async () => {
+ await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', {
+ resume_token: 'resume-token-123',
+ })
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
new file mode 100644
index 00000000..10f6247a
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
@@ -0,0 +1,131 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ type UpdateSettingsRequest,
+} from "@/api/admin/settings";
+
+describe("admin settings auth source defaults helpers", () => {
+ it("builds auth source defaults state from flat settings fields", () => {
+ const state = buildAuthSourceDefaultsState({
+ auth_source_default_email_balance: 9.5,
+ auth_source_default_email_concurrency: 3,
+ auth_source_default_email_subscriptions: [
+ { group_id: 1, validity_days: 30 },
+ ],
+ auth_source_default_email_grant_on_signup: false,
+ auth_source_default_email_grant_on_first_bind: true,
+ auth_source_default_linuxdo_balance: 6,
+ auth_source_default_linuxdo_concurrency: 8,
+ auth_source_default_linuxdo_subscriptions: [
+ { group_id: 2, validity_days: 60 },
+ ],
+ auth_source_default_linuxdo_grant_on_signup: true,
+ auth_source_default_linuxdo_grant_on_first_bind: false,
+ });
+
+ expect(state.email).toEqual({
+ balance: 9.5,
+ concurrency: 3,
+ subscriptions: [{ group_id: 1, validity_days: 30 }],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ });
+ expect(state.linuxdo).toEqual({
+ balance: 6,
+ concurrency: 8,
+ subscriptions: [{ group_id: 2, validity_days: 60 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ });
+ expect(state.oidc).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ });
+ expect(state.wechat).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ });
+ });
+
+ it("defaults grant-on-signup to disabled when settings are missing", () => {
+ const state = buildAuthSourceDefaultsState({});
+
+ expect(state.email.grant_on_signup).toBe(false);
+ expect(state.linuxdo.grant_on_signup).toBe(false);
+ expect(state.oidc.grant_on_signup).toBe(false);
+ expect(state.wechat.grant_on_signup).toBe(false);
+ });
+
+ it("appends auth source defaults back onto update payload", () => {
+ const payload: UpdateSettingsRequest = {
+ site_name: "Sub2API",
+ };
+
+ appendAuthSourceDefaultsToUpdateRequest(payload, {
+ email: {
+ balance: 1.25,
+ concurrency: 2,
+ subscriptions: [{ group_id: 3, validity_days: 7 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ },
+ linuxdo: {
+ balance: 0,
+ concurrency: 6,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ },
+ oidc: {
+ balance: 4,
+ concurrency: 9,
+ subscriptions: [{ group_id: 9, validity_days: 90 }],
+ grant_on_signup: true,
+ grant_on_first_bind: true,
+ },
+ wechat: {
+ balance: 2,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ },
+ });
+
+ expect(payload).toMatchObject({
+ site_name: "Sub2API",
+ auth_source_default_email_balance: 1.25,
+ auth_source_default_email_concurrency: 2,
+ auth_source_default_email_subscriptions: [
+ { group_id: 3, validity_days: 7 },
+ ],
+ auth_source_default_email_grant_on_signup: true,
+ auth_source_default_email_grant_on_first_bind: false,
+ auth_source_default_linuxdo_balance: 0,
+ auth_source_default_linuxdo_concurrency: 6,
+ auth_source_default_linuxdo_subscriptions: [],
+ auth_source_default_linuxdo_grant_on_signup: false,
+ auth_source_default_linuxdo_grant_on_first_bind: true,
+ auth_source_default_oidc_balance: 4,
+ auth_source_default_oidc_concurrency: 9,
+ auth_source_default_oidc_subscriptions: [
+ { group_id: 9, validity_days: 90 },
+ ],
+ auth_source_default_oidc_grant_on_signup: true,
+ auth_source_default_oidc_grant_on_first_bind: true,
+ auth_source_default_wechat_balance: 2,
+ auth_source_default_wechat_concurrency: 5,
+ auth_source_default_wechat_subscriptions: [],
+ auth_source_default_wechat_grant_on_signup: false,
+ auth_source_default_wechat_grant_on_first_bind: false,
+ });
+ });
+});
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
new file mode 100644
index 00000000..ad355afe
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -0,0 +1,63 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
+} from '@/api/admin/settings'
+
+describe('admin settings payment visible method helpers', () => {
+ it('normalizes aliases into canonical source keys per visible method', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay')
+
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'wechat')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay')
+ })
+
+ it('rejects unknown or cross-method source values', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('')
+ })
+
+ it('exposes method-scoped source options instead of arbitrary strings', () => {
+ expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
+ {
+ value: '',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
+ },
+ {
+ value: 'official_alipay',
+ labelZh: '支付宝官方',
+ labelEn: 'Official Alipay',
+ },
+ {
+ value: 'easypay_alipay',
+ labelZh: '易支付支付宝',
+ labelEn: 'EasyPay Alipay',
+ },
+ ])
+
+ expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
+ {
+ value: '',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
+ },
+ {
+ value: 'official_wxpay',
+ labelZh: '微信官方',
+ labelEn: 'Official WeChat Pay',
+ },
+ {
+ value: 'easypay_wxpay',
+ labelZh: '易支付微信',
+ labelEn: 'EasyPay WeChat Pay',
+ },
+ ])
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
new file mode 100644
index 00000000..eccb7214
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
@@ -0,0 +1,21 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ defaultWeChatConnectScopesForMode,
+ normalizeWeChatConnectMode,
+} from "@/api/admin/settings";
+
+describe("admin settings wechat connect helpers", () => {
+ it("normalizes legacy or noisy mode values to the backend contract", () => {
+ expect(normalizeWeChatConnectMode("OPEN")).toBe("open");
+ expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open");
+ expect(normalizeWeChatConnectMode("mp")).toBe("mp");
+ expect(normalizeWeChatConnectMode("official_account")).toBe("mp");
+ expect(normalizeWeChatConnectMode("unknown")).toBe("open");
+ });
+
+ it("maps each mode to the backend default scopes", () => {
+ expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login");
+ expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo");
+ });
+});
diff --git a/frontend/src/api/__tests__/user.spec.ts b/frontend/src/api/__tests__/user.spec.ts
new file mode 100644
index 00000000..887046da
--- /dev/null
+++ b/frontend/src/api/__tests__/user.spec.ts
@@ -0,0 +1,32 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+describe('user api oauth binding urls', () => {
+ beforeEach(() => {
+ vi.resetModules()
+ vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/api/v1')
+ })
+
+ afterEach(() => {
+ vi.unstubAllEnvs()
+ })
+
+ it('builds third-party bind urls against the bind start endpoint', async () => {
+ const { buildOAuthBindingStartURL } = await import('@/api/user')
+
+ expect(buildOAuthBindingStartURL('linuxdo', { redirectTo: '/settings/profile' })).toBe(
+ 'https://api.example.com/api/v1/auth/oauth/linuxdo/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user'
+ )
+ expect(
+ buildOAuthBindingStartURL('wechat', {
+ redirectTo: '/settings/profile',
+ wechatOAuthSettings: {
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ wechat_oauth_mobile_enabled: false
+ }
+ })
+ ).toBe(
+ 'https://api.example.com/api/v1/auth/oauth/wechat/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user&mode=open'
+ )
+ })
+})
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index a146f1f7..8a127793 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -370,8 +370,8 @@ export async function batchUpdateCredentials(request: {
* @returns Success confirmation
*/
export async function bulkUpdate(
- accountIds: number[],
- updates: Record
+ accountIdsOrPayload: number[] | Record,
+ updates?: Record
): Promise<{
success: number
failed: number
@@ -379,16 +379,19 @@ export async function bulkUpdate(
failed_ids?: number[]
results: Array<{ account_id: number; success: boolean; error?: string }>
}> {
+ const payload = Array.isArray(accountIdsOrPayload)
+ ? {
+ account_ids: accountIdsOrPayload,
+ ...(updates ?? {})
+ }
+ : accountIdsOrPayload
const { data } = await apiClient.post<{
success: number
failed: number
success_ids?: number[]
failed_ids?: number[]
results: Array<{ account_id: number; success: boolean; error?: string }>
- }>('/admin/accounts/bulk-update', {
- account_ids: accountIds,
- ...updates
- })
+ }>('/admin/accounts/bulk-update', payload)
return data
}
diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts
new file mode 100644
index 00000000..22639bd2
--- /dev/null
+++ b/frontend/src/api/admin/affiliates.ts
@@ -0,0 +1,108 @@
+/**
+ * Admin Affiliate API endpoints
+ * Manage per-user affiliate (邀请返利) configurations:
+ * exclusive invite codes (overrides aff_code) and exclusive rebate rates.
+ */
+
+import { apiClient } from '../client'
+import type { PaginatedResponse } from '@/types'
+
+export interface AffiliateAdminEntry {
+ user_id: number
+ email: string
+ username: string
+ aff_code: string
+ aff_code_custom: boolean
+ aff_rebate_rate_percent?: number | null
+ aff_count: number
+}
+
+export interface ListAffiliateUsersParams {
+ page?: number
+ page_size?: number
+ search?: string
+}
+
+export interface UpdateAffiliateUserRequest {
+ aff_code?: string
+ aff_rebate_rate_percent?: number | null
+ /** Set true to explicitly clear the per-user rate (sets it to NULL). */
+ clear_rebate_rate?: boolean
+}
+
+export interface BatchSetRateRequest {
+ user_ids: number[]
+ aff_rebate_rate_percent?: number | null
+ /** Set true to clear rates instead of setting. */
+ clear?: boolean
+}
+
+export interface SimpleUser {
+ id: number
+ email: string
+ username: string
+}
+
+export async function listUsers(
+ params: ListAffiliateUsersParams = {},
+): Promise> {
+ const { data } = await apiClient.get>(
+ '/admin/affiliates/users',
+ {
+ params: {
+ page: params.page ?? 1,
+ page_size: params.page_size ?? 20,
+ search: params.search ?? '',
+ },
+ },
+ )
+ return data
+}
+
+export async function lookupUsers(q: string): Promise {
+ const { data } = await apiClient.get(
+ '/admin/affiliates/users/lookup',
+ { params: { q } },
+ )
+ return data
+}
+
+export async function updateUserSettings(
+ userId: number,
+ payload: UpdateAffiliateUserRequest,
+): Promise<{ user_id: number }> {
+ const { data } = await apiClient.put<{ user_id: number }>(
+ `/admin/affiliates/users/${userId}`,
+ payload,
+ )
+ return data
+}
+
+export async function clearUserSettings(
+ userId: number,
+): Promise<{ user_id: number }> {
+ const { data } = await apiClient.delete<{ user_id: number }>(
+ `/admin/affiliates/users/${userId}`,
+ )
+ return data
+}
+
+export async function batchSetRate(
+ payload: BatchSetRateRequest,
+): Promise<{ affected: number }> {
+ const { data } = await apiClient.post<{ affected: number }>(
+ '/admin/affiliates/users/batch-rate',
+ payload,
+ )
+ return data
+}
+
+export const affiliatesAPI = {
+ listUsers,
+ lookupUsers,
+ updateUserSettings,
+ clearUserSettings,
+ batchSetRate,
+}
+
+export default affiliatesAPI
diff --git a/frontend/src/api/admin/channelMonitor.ts b/frontend/src/api/admin/channelMonitor.ts
new file mode 100644
index 00000000..949c4bc8
--- /dev/null
+++ b/frontend/src/api/admin/channelMonitor.ts
@@ -0,0 +1,202 @@
+/**
+ * Admin Channel Monitor API endpoints
+ * Handles channel monitor (uptime/health) management for administrators
+ */
+
+import { apiClient } from '../client'
+
+export type Provider = 'openai' | 'anthropic' | 'gemini'
+export type MonitorStatus = 'operational' | 'degraded' | 'failed' | 'error'
+export type BodyOverrideMode = 'off' | 'merge' | 'replace'
+
+export interface ChannelMonitor {
+ id: number
+ name: string
+ provider: Provider
+ endpoint: string
+ api_key_masked: string
+ /**
+ * True when the stored encrypted API key cannot be decrypted (e.g. the
+ * encryption key has changed). Admin must re-edit the monitor to provide
+ * a fresh key. Backend skips checks for these monitors.
+ */
+ api_key_decrypt_failed?: boolean
+ primary_model: string
+ extra_models: string[]
+ group_name: string
+ enabled: boolean
+ interval_seconds: number
+ last_checked_at: string | null
+ created_by: number
+ created_at: string
+ updated_at: string
+ /** Latest status of the primary model (empty when no history yet) */
+ primary_status: MonitorStatus | ''
+ /** Latest latency of the primary model in ms (null when no history yet) */
+ primary_latency_ms: number | null
+ /** Primary model 7-day availability percentage (0-100) */
+ availability_7d: number
+ /** Latest status per extra model (used for hover tooltip) */
+ extra_models_status: ExtraModelStatus[]
+ /** 请求自定义快照字段(高级设置) */
+ template_id: number | null
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
+}
+
+export interface ExtraModelStatus {
+ model: string
+ status: MonitorStatus | ''
+ latency_ms: number | null
+}
+
+export interface ListParams {
+ page?: number
+ page_size?: number
+ provider?: Provider
+ enabled?: boolean
+ search?: string
+}
+
+export interface ListResponse {
+ items: ChannelMonitor[]
+ total: number
+ page: number
+ page_size: number
+ pages: number
+}
+
+export interface CreateParams {
+ name: string
+ provider: Provider
+ endpoint: string
+ api_key: string
+ primary_model: string
+ extra_models?: string[]
+ group_name?: string
+ enabled?: boolean
+ interval_seconds: number
+ template_id?: number | null
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+// Update request: api_key 空串 = 不修改;clear_template=true 时把 template_id 置空
+export type UpdateParams = Partial & {
+ clear_template?: boolean
+}
+
+export interface CheckResult {
+ model: string
+ status: MonitorStatus
+ latency_ms: number | null
+ ping_latency_ms: number | null
+ message: string
+ checked_at: string
+}
+
+export interface RunNowResponse {
+ results: CheckResult[]
+}
+
+export interface HistoryItem {
+ id: number
+ model: string
+ status: MonitorStatus
+ latency_ms: number | null
+ ping_latency_ms: number | null
+ message: string
+ checked_at: string
+}
+
+export interface HistoryParams {
+ model?: string
+ limit?: number
+}
+
+export interface HistoryResponse {
+ items: HistoryItem[]
+}
+
+/**
+ * List channel monitors with pagination and filters
+ */
+export async function list(
+ params: ListParams = {},
+ options?: { signal?: AbortSignal }
+): Promise {
+ const { data } = await apiClient.get('/admin/channel-monitors', {
+ params,
+ signal: options?.signal,
+ })
+ return data
+}
+
+/**
+ * Get a channel monitor by ID
+ */
+export async function get(id: number): Promise {
+ const { data } = await apiClient.get(`/admin/channel-monitors/${id}`)
+ return data
+}
+
+/**
+ * Create a new channel monitor
+ */
+export async function create(params: CreateParams): Promise {
+ const { data } = await apiClient.post('/admin/channel-monitors', params)
+ return data
+}
+
+/**
+ * Update an existing channel monitor.
+ * api_key field: empty string means "do not modify".
+ */
+export async function update(id: number, params: UpdateParams): Promise {
+ const { data } = await apiClient.put(`/admin/channel-monitors/${id}`, params)
+ return data
+}
+
+/**
+ * Delete a channel monitor
+ */
+export async function del(id: number): Promise {
+ await apiClient.delete(`/admin/channel-monitors/${id}`)
+}
+
+/**
+ * Trigger an immediate manual check for a channel monitor.
+ * Returns the latest check results for primary + extra models.
+ */
+export async function runNow(id: number): Promise {
+ const { data } = await apiClient.post(`/admin/channel-monitors/${id}/run`)
+ return data
+}
+
+/**
+ * List historical check results for a monitor.
+ */
+export async function listHistory(
+ id: number,
+ params: HistoryParams = {}
+): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitors/${id}/history`,
+ { params }
+ )
+ return data
+}
+
+export const channelMonitorAPI = {
+ list,
+ get,
+ create,
+ update,
+ del,
+ runNow,
+ listHistory,
+}
+
+export default channelMonitorAPI
diff --git a/frontend/src/api/admin/channelMonitorTemplate.ts b/frontend/src/api/admin/channelMonitorTemplate.ts
new file mode 100644
index 00000000..01b3c2d0
--- /dev/null
+++ b/frontend/src/api/admin/channelMonitorTemplate.ts
@@ -0,0 +1,132 @@
+/**
+ * Admin Channel Monitor Request Template API.
+ *
+ * 模板 = 一组可复用的 headers + 可选 body 覆盖配置。
+ * 应用到监控 = 拷贝快照;模板后续变动不自动同步,需手动点「应用到关联监控」刷新。
+ */
+
+import { apiClient } from '../client'
+import type { BodyOverrideMode, Provider } from './channelMonitor'
+
+export interface ChannelMonitorTemplate {
+ id: number
+ name: string
+ provider: Provider
+ description: string
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
+ created_at: string
+ updated_at: string
+ /** 关联的监控数量(快照来自此模板,仅 template_id 匹配即可) */
+ associated_monitors: number
+}
+
+export interface ListParams {
+ provider?: Provider
+}
+
+export interface ListResponse {
+ items: ChannelMonitorTemplate[]
+}
+
+export interface CreateParams {
+ name: string
+ provider: Provider
+ description?: string
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+export interface UpdateParams {
+ name?: string
+ description?: string
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+export interface ApplyResponse {
+ affected: number
+}
+
+export interface AssociatedMonitorBrief {
+ id: number
+ name: string
+ provider: Provider
+ enabled: boolean
+}
+
+export interface AssociatedMonitorsResponse {
+ items: AssociatedMonitorBrief[]
+}
+
+export async function list(params: ListParams = {}): Promise {
+ const { data } = await apiClient.get('/admin/channel-monitor-templates', {
+ params,
+ })
+ return data
+}
+
+export async function get(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitor-templates/${id}`,
+ )
+ return data
+}
+
+export async function create(params: CreateParams): Promise {
+ const { data } = await apiClient.post(
+ '/admin/channel-monitor-templates',
+ params,
+ )
+ return data
+}
+
+export async function update(id: number, params: UpdateParams): Promise {
+ const { data } = await apiClient.put(
+ `/admin/channel-monitor-templates/${id}`,
+ params,
+ )
+ return data
+}
+
+export async function del(id: number): Promise {
+ await apiClient.delete(`/admin/channel-monitor-templates/${id}`)
+}
+
+/**
+ * Apply the template to the specified associated monitors (overwrite snapshot fields).
+ * monitorIds must be a non-empty subset of the template's associated monitors.
+ * Returns count of actually affected monitors.
+ */
+export async function apply(id: number, monitorIds: number[]): Promise {
+ const { data } = await apiClient.post(
+ `/admin/channel-monitor-templates/${id}/apply`,
+ { monitor_ids: monitorIds },
+ )
+ return data
+}
+
+/**
+ * List monitors currently associated to this template (used by apply picker).
+ */
+export async function listAssociatedMonitors(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitor-templates/${id}/monitors`,
+ )
+ return data
+}
+
+export const channelMonitorTemplateAPI = {
+ list,
+ get,
+ create,
+ update,
+ del,
+ apply,
+ listAssociatedMonitors,
+}
+
+export default channelMonitorTemplateAPI
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index f129ceaa..9d430134 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -4,8 +4,9 @@
*/
import { apiClient } from '../client'
+import type { BillingMode, ChannelStatus, BillingModelSource } from '@/constants/channel'
-export type BillingMode = 'token' | 'per_request' | 'image'
+export type { BillingMode } from '@/constants/channel'
export interface PricingInterval {
id?: number
@@ -46,8 +47,8 @@ export interface Channel {
id: number
name: string
description: string
- status: string
- billing_model_source: string // "requested" | "upstream"
+ status: ChannelStatus
+ billing_model_source: BillingModelSource
restrict_models: boolean
features_config?: Record
group_ids: number[]
diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts
index 8739d5cb..6b94b799 100644
--- a/frontend/src/api/admin/groups.ts
+++ b/frontend/src/api/admin/groups.ts
@@ -164,7 +164,8 @@ export interface GroupRateMultiplierEntry {
user_email: string
user_notes: string
user_status: string
- rate_multiplier: number
+ rate_multiplier?: number | null
+ rpm_override?: number | null
}
/**
@@ -205,9 +206,7 @@ export async function clearGroupRateMultipliers(id: number): Promise<{ message:
/**
* Batch set rate multipliers for users in a group
- * @param id - Group ID
- * @param entries - Array of { user_id, rate_multiplier }
- * @returns Success confirmation
+ * Only touches rate_multiplier column; preserves rpm_override on existing rows.
*/
export async function batchSetGroupRateMultipliers(
id: number,
@@ -220,6 +219,60 @@ export async function batchSetGroupRateMultipliers(
return data
}
+/**
+ * RPM override entry for a user in a group
+ */
+export interface GroupRPMOverrideEntry {
+ user_id: number
+ user_name: string
+ user_email: string
+ user_notes: string
+ user_status: string
+ rpm_override: number
+}
+
+/**
+ * Get RPM overrides for users in a group (subset of rate-multipliers endpoint).
+ */
+export async function getGroupRPMOverrides(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/groups/${id}/rate-multipliers`
+ )
+ return data
+ .filter(e => e.rpm_override != null)
+ .map(e => ({
+ user_id: e.user_id,
+ user_name: e.user_name,
+ user_email: e.user_email,
+ user_notes: e.user_notes,
+ user_status: e.user_status,
+ rpm_override: e.rpm_override as number
+ }))
+}
+
+/**
+ * Batch set RPM overrides for users in a group.
+ * Only touches rpm_override column; preserves rate_multiplier on existing rows.
+ */
+export async function batchSetGroupRPMOverrides(
+ id: number,
+ entries: Array<{ user_id: number; rpm_override: number }>
+): Promise<{ message: string }> {
+ const { data } = await apiClient.put<{ message: string }>(
+ `/admin/groups/${id}/rpm-overrides`,
+ { entries }
+ )
+ return data
+}
+
+/**
+ * Clear all RPM overrides for a group (preserves rate_multiplier).
+ */
+export async function clearGroupRPMOverrides(id: number): Promise<{ message: string }> {
+ const { data } = await apiClient.delete<{ message: string }>(`/admin/groups/${id}/rpm-overrides`)
+ return data
+}
+
/**
* Get usage summary (today + cumulative cost) for all groups
* @param timezone - IANA timezone string (e.g. "Asia/Shanghai")
@@ -262,6 +315,9 @@ export const groupsAPI = {
getGroupRateMultipliers,
clearGroupRateMultipliers,
batchSetGroupRateMultipliers,
+ getGroupRPMOverrides,
+ clearGroupRPMOverrides,
+ batchSetGroupRPMOverrides,
updateSortOrder,
getUsageSummary,
getCapacitySummary
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index 72597365..80241794 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -26,7 +26,10 @@ import scheduledTestsAPI from './scheduledTests'
import backupAPI from './backup'
import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
import channelsAPI from './channels'
+import channelMonitorAPI from './channelMonitor'
+import channelMonitorTemplateAPI from './channelMonitorTemplate'
import adminPaymentAPI from './payment'
+import affiliatesAPI from './affiliates'
/**
* Unified admin API object for convenient access
@@ -55,7 +58,10 @@ export const adminAPI = {
backup: backupAPI,
tlsFingerprintProfiles: tlsFingerprintProfileAPI,
channels: channelsAPI,
- payment: adminPaymentAPI
+ channelMonitor: channelMonitorAPI,
+ channelMonitorTemplate: channelMonitorTemplateAPI,
+ payment: adminPaymentAPI,
+ affiliates: affiliatesAPI
}
export {
@@ -82,7 +88,10 @@ export {
backupAPI,
tlsFingerprintProfileAPI,
channelsAPI,
- adminPaymentAPI
+ channelMonitorAPI,
+ channelMonitorTemplateAPI,
+ adminPaymentAPI,
+ affiliatesAPI
}
export default adminAPI
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 1e4a3053..35eef9de 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -3,12 +3,293 @@
* Handles system settings management for administrators
*/
-import { apiClient } from '../client'
-import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types'
+import { apiClient } from "../client";
+import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types";
export interface DefaultSubscriptionSetting {
- group_id: number
- validity_days: number
+ group_id: number;
+ validity_days: number;
+}
+
+export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat";
+
+export interface AuthSourceDefaultsValue {
+ balance: number;
+ concurrency: number;
+ subscriptions: DefaultSubscriptionSetting[];
+ grant_on_signup: boolean;
+ grant_on_first_bind: boolean;
+}
+
+export type AuthSourceDefaultsState = Record<
+ AuthSourceType,
+ AuthSourceDefaultsValue
+>;
+export type PaymentVisibleMethod = "alipay" | "wxpay";
+export type PaymentVisibleMethodSource =
+ | ""
+ | "official_alipay"
+ | "easypay_alipay"
+ | "official_wxpay"
+ | "easypay_wxpay";
+export type WeChatConnectMode = "open" | "mp" | "mobile";
+
+export interface PaymentVisibleMethodSourceOption {
+ value: PaymentVisibleMethodSource;
+ labelZh: string;
+ labelEn: string;
+}
+
+export interface WeChatConnectModeOption {
+ value: WeChatConnectMode;
+ labelZh: string;
+ labelEn: string;
+}
+
+const AUTH_SOURCE_TYPES: AuthSourceType[] = [
+ "email",
+ "linuxdo",
+ "oidc",
+ "wechat",
+];
+const AUTH_SOURCE_DEFAULT_BALANCE = 0;
+const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5;
+const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
+ PaymentVisibleMethod,
+ PaymentVisibleMethodSourceOption[]
+> = {
+ alipay: [
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_alipay",
+ labelZh: "支付宝官方",
+ labelEn: "Official Alipay",
+ },
+ {
+ value: "easypay_alipay",
+ labelZh: "易支付支付宝",
+ labelEn: "EasyPay Alipay",
+ },
+ ],
+ wxpay: [
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_wxpay",
+ labelZh: "微信官方",
+ labelEn: "Official WeChat Pay",
+ },
+ {
+ value: "easypay_wxpay",
+ labelZh: "易支付微信",
+ labelEn: "EasyPay WeChat Pay",
+ },
+ ],
+};
+const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record<
+ PaymentVisibleMethod,
+ Record
+> = {
+ alipay: {
+ official_alipay: "official_alipay",
+ alipay: "official_alipay",
+ alipay_direct: "official_alipay",
+ official: "official_alipay",
+ easypay_alipay: "easypay_alipay",
+ easypay: "easypay_alipay",
+ },
+ wxpay: {
+ official_wxpay: "official_wxpay",
+ wxpay: "official_wxpay",
+ wxpay_direct: "official_wxpay",
+ wechat: "official_wxpay",
+ official: "official_wxpay",
+ easypay_wxpay: "easypay_wxpay",
+ easypay: "easypay_wxpay",
+ },
+};
+const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [
+ { value: "open", labelZh: "PC 应用", labelEn: "PC App" },
+ {
+ value: "mp",
+ labelZh: "公众号",
+ labelEn: "Official Account",
+ },
+ {
+ value: "mobile",
+ labelZh: "移动应用",
+ labelEn: "Mobile App",
+ },
+];
+const WECHAT_CONNECT_MODE_ALIASES: Record = {
+ open: "open",
+ open_platform: "open",
+ official: "open",
+ wx_open: "open",
+ mp: "mp",
+ official_account: "mp",
+ wechat_mp: "mp",
+ mini_program: "mp",
+ mobile: "mobile",
+ mobile_app: "mobile",
+ native_app: "mobile",
+};
+
+export function normalizeDefaultSubscriptionSettings(
+ subscriptions: DefaultSubscriptionSetting[] | null | undefined,
+): DefaultSubscriptionSetting[] {
+ if (!Array.isArray(subscriptions)) return [];
+
+ return subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item) => ({
+ group_id: Math.floor(item.group_id),
+ validity_days: Math.min(
+ 36500,
+ Math.max(1, Math.floor(item.validity_days)),
+ ),
+ }));
+}
+
+export function buildAuthSourceDefaultsState(
+ settings: Partial,
+): AuthSourceDefaultsState {
+ const raw = settings as Record;
+
+ return AUTH_SOURCE_TYPES.reduce((acc, source) => {
+ const subscriptions = raw[`auth_source_default_${source}_subscriptions`];
+ acc[source] = {
+ balance: Number(
+ raw[`auth_source_default_${source}_balance`] ??
+ AUTH_SOURCE_DEFAULT_BALANCE,
+ ),
+ concurrency: Math.max(
+ 1,
+ Number(
+ raw[`auth_source_default_${source}_concurrency`] ??
+ AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ ),
+ subscriptions: normalizeDefaultSubscriptionSettings(
+ Array.isArray(subscriptions)
+ ? (subscriptions as DefaultSubscriptionSetting[])
+ : [],
+ ),
+ grant_on_signup:
+ raw[`auth_source_default_${source}_grant_on_signup`] === true,
+ grant_on_first_bind:
+ raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
+ };
+ return acc;
+ }, {} as AuthSourceDefaultsState);
+}
+
+export function appendAuthSourceDefaultsToUpdateRequest(
+ payload: UpdateSettingsRequest,
+ authSourceDefaults: AuthSourceDefaultsState,
+): UpdateSettingsRequest {
+ const target = payload as Record;
+
+ for (const source of AUTH_SOURCE_TYPES) {
+ const current = authSourceDefaults[source];
+ target[`auth_source_default_${source}_balance`] =
+ Number(current.balance) || 0;
+ target[`auth_source_default_${source}_concurrency`] = Math.max(
+ 1,
+ Math.floor(
+ Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ );
+ target[`auth_source_default_${source}_subscriptions`] =
+ normalizeDefaultSubscriptionSettings(current.subscriptions);
+ target[`auth_source_default_${source}_grant_on_signup`] =
+ current.grant_on_signup;
+ target[`auth_source_default_${source}_grant_on_first_bind`] =
+ current.grant_on_first_bind;
+ }
+
+ return payload;
+}
+
+export function getPaymentVisibleMethodSourceOptions(
+ method: PaymentVisibleMethod,
+): PaymentVisibleMethodSourceOption[] {
+ return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method];
+}
+
+export function normalizePaymentVisibleMethodSource(
+ method: PaymentVisibleMethod,
+ source: unknown,
+): PaymentVisibleMethodSource {
+ if (typeof source !== "string") return "";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "";
+
+ return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? "";
+}
+
+export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] {
+ return WECHAT_CONNECT_MODE_OPTIONS;
+}
+
+export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode {
+ if (typeof source !== "string") return "open";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "open";
+
+ return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open";
+}
+
+export function defaultWeChatConnectScopesForMode(mode: unknown): string {
+ switch (normalizeWeChatConnectMode(mode)) {
+ case "mp":
+ return "snsapi_userinfo";
+ case "mobile":
+ return "";
+ default:
+ return "snsapi_login";
+ }
+}
+
+export function resolveWeChatConnectModeCapabilities(
+ openEnabled: unknown,
+ mpEnabled: unknown,
+ mobileEnabled: unknown,
+ legacyMode: unknown,
+): { openEnabled: boolean; mpEnabled: boolean; mobileEnabled: boolean } {
+ if (
+ typeof openEnabled === "boolean" ||
+ typeof mpEnabled === "boolean" ||
+ typeof mobileEnabled === "boolean"
+ ) {
+ return {
+ openEnabled: openEnabled === true,
+ mpEnabled: mpEnabled === true,
+ mobileEnabled: mobileEnabled === true,
+ };
+ }
+
+ switch (normalizeWeChatConnectMode(legacyMode)) {
+ case "mp":
+ return { openEnabled: false, mpEnabled: true, mobileEnabled: false };
+ case "mobile":
+ return { openEnabled: false, mpEnabled: false, mobileEnabled: true };
+ default:
+ return { openEnabled: true, mpEnabled: false, mobileEnabled: false };
+ }
+}
+
+export function deriveWeChatConnectStoredMode(
+ openEnabled: boolean,
+ mpEnabled: boolean,
+ mobileEnabled: boolean,
+ legacyMode: unknown,
+): WeChatConnectMode {
+ if (mpEnabled) return "mp";
+ if (mobileEnabled) return "mobile";
+ if (openEnabled) return "open";
+ return normalizeWeChatConnectMode(legacyMode);
}
/**
@@ -16,241 +297,365 @@ export interface DefaultSubscriptionSetting {
*/
export interface SystemSettings {
// Registration settings
- registration_enabled: boolean
- email_verify_enabled: boolean
- registration_email_suffix_whitelist: string[]
- promo_code_enabled: boolean
- password_reset_enabled: boolean
- frontend_url: string
- invitation_code_enabled: boolean
- totp_enabled: boolean // TOTP 双因素认证
- totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
+ registration_enabled: boolean;
+ email_verify_enabled: boolean;
+ registration_email_suffix_whitelist: string[];
+ promo_code_enabled: boolean;
+ password_reset_enabled: boolean;
+ frontend_url: string;
+ invitation_code_enabled: boolean;
+ totp_enabled: boolean; // TOTP 双因素认证
+ totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
- default_balance: number
- default_concurrency: number
- default_subscriptions: DefaultSubscriptionSetting[]
+ default_balance: number;
+ affiliate_rebate_rate: number;
+ affiliate_rebate_freeze_hours: number;
+ affiliate_rebate_duration_days: number;
+ affiliate_rebate_per_invitee_cap: number;
+ default_concurrency: number;
+ default_user_rpm_limit: number;
+ default_subscriptions: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
// OEM settings
- site_name: string
- site_logo: string
- site_subtitle: string
- api_base_url: string
- contact_info: string
- doc_url: string
- home_content: string
- hide_ccs_import_button: boolean
- table_default_page_size: number
- table_page_size_options: number[]
- backend_mode_enabled: boolean
- custom_menu_items: CustomMenuItem[]
- custom_endpoints: CustomEndpoint[]
+ site_name: string;
+ site_logo: string;
+ site_subtitle: string;
+ api_base_url: string;
+ contact_info: string;
+ doc_url: string;
+ home_content: string;
+ hide_ccs_import_button: boolean;
+ table_default_page_size: number;
+ table_page_size_options: number[];
+ backend_mode_enabled: boolean;
+ custom_menu_items: CustomMenuItem[];
+ custom_endpoints: CustomEndpoint[];
// SMTP settings
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password_configured: boolean
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password_configured: boolean;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
// Cloudflare Turnstile settings
- turnstile_enabled: boolean
- turnstile_site_key: string
- turnstile_secret_key_configured: boolean
+ turnstile_enabled: boolean;
+ turnstile_site_key: string;
+ turnstile_secret_key_configured: boolean;
// LinuxDo Connect OAuth settings
- linuxdo_connect_enabled: boolean
- linuxdo_connect_client_id: string
- linuxdo_connect_client_secret_configured: boolean
- linuxdo_connect_redirect_url: string
+ linuxdo_connect_enabled: boolean;
+ linuxdo_connect_client_id: string;
+ linuxdo_connect_client_secret_configured: boolean;
+ linuxdo_connect_redirect_url: string;
+
+ // WeChat Connect OAuth settings
+ wechat_connect_enabled: boolean;
+ wechat_connect_app_id: string;
+ wechat_connect_app_secret_configured: boolean;
+ wechat_connect_open_app_id?: string;
+ wechat_connect_open_app_secret_configured?: boolean;
+ wechat_connect_mp_app_id?: string;
+ wechat_connect_mp_app_secret_configured?: boolean;
+ wechat_connect_mobile_app_id?: string;
+ wechat_connect_mobile_app_secret_configured?: boolean;
+ wechat_connect_open_enabled?: boolean;
+ wechat_connect_mp_enabled?: boolean;
+ wechat_connect_mobile_enabled?: boolean;
+ wechat_connect_mode: string;
+ wechat_connect_scopes: string;
+ wechat_connect_redirect_url: string;
+ wechat_connect_frontend_redirect_url: string;
// Generic OIDC OAuth settings
- oidc_connect_enabled: boolean
- oidc_connect_provider_name: string
- oidc_connect_client_id: string
- oidc_connect_client_secret_configured: boolean
- oidc_connect_issuer_url: string
- oidc_connect_discovery_url: string
- oidc_connect_authorize_url: string
- oidc_connect_token_url: string
- oidc_connect_userinfo_url: string
- oidc_connect_jwks_url: string
- oidc_connect_scopes: string
- oidc_connect_redirect_url: string
- oidc_connect_frontend_redirect_url: string
- oidc_connect_token_auth_method: string
- oidc_connect_use_pkce: boolean
- oidc_connect_validate_id_token: boolean
- oidc_connect_allowed_signing_algs: string
- oidc_connect_clock_skew_seconds: number
- oidc_connect_require_email_verified: boolean
- oidc_connect_userinfo_email_path: string
- oidc_connect_userinfo_id_path: string
- oidc_connect_userinfo_username_path: string
+ oidc_connect_enabled: boolean;
+ oidc_connect_provider_name: string;
+ oidc_connect_client_id: string;
+ oidc_connect_client_secret_configured: boolean;
+ oidc_connect_issuer_url: string;
+ oidc_connect_discovery_url: string;
+ oidc_connect_authorize_url: string;
+ oidc_connect_token_url: string;
+ oidc_connect_userinfo_url: string;
+ oidc_connect_jwks_url: string;
+ oidc_connect_scopes: string;
+ oidc_connect_redirect_url: string;
+ oidc_connect_frontend_redirect_url: string;
+ oidc_connect_token_auth_method: string;
+ oidc_connect_use_pkce: boolean;
+ oidc_connect_validate_id_token: boolean;
+ oidc_connect_allowed_signing_algs: string;
+ oidc_connect_clock_skew_seconds: number;
+ oidc_connect_require_email_verified: boolean;
+ oidc_connect_userinfo_email_path: string;
+ oidc_connect_userinfo_id_path: string;
+ oidc_connect_userinfo_username_path: string;
// Model fallback configuration
- enable_model_fallback: boolean
- fallback_model_anthropic: string
- fallback_model_openai: string
- fallback_model_gemini: string
- fallback_model_antigravity: string
+ enable_model_fallback: boolean;
+ fallback_model_anthropic: string;
+ fallback_model_openai: string;
+ fallback_model_gemini: string;
+ fallback_model_antigravity: string;
// Identity patch configuration (Claude -> Gemini)
- enable_identity_patch: boolean
- identity_patch_prompt: string
+ enable_identity_patch: boolean;
+ identity_patch_prompt: string;
// Ops Monitoring (vNext)
- ops_monitoring_enabled: boolean
- ops_realtime_monitoring_enabled: boolean
- ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds: number
+ ops_monitoring_enabled: boolean;
+ ops_realtime_monitoring_enabled: boolean;
+ ops_query_mode_default: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds: number;
// Claude Code version check
- min_claude_code_version: string
- max_claude_code_version: string
+ min_claude_code_version: string;
+ max_claude_code_version: string;
// 分组隔离
- allow_ungrouped_key_scheduling: boolean
+ allow_ungrouped_key_scheduling: boolean;
// Gateway forwarding behavior
- enable_fingerprint_unification: boolean
- enable_metadata_passthrough: boolean
- enable_cch_signing: boolean
- web_search_emulation_enabled?: boolean
+ enable_fingerprint_unification: boolean;
+ enable_metadata_passthrough: boolean;
+ enable_cch_signing: boolean;
+ enable_anthropic_cache_ttl_1h_injection: boolean;
+ web_search_emulation_enabled?: boolean;
// Payment configuration
- payment_enabled: boolean
- payment_min_amount: number
- payment_max_amount: number
- payment_daily_limit: number
- payment_order_timeout_minutes: number
- payment_max_pending_orders: number
- payment_enabled_types: string[]
- payment_balance_disabled: boolean
- payment_balance_recharge_multiplier: number
- payment_recharge_fee_rate: number
- payment_load_balance_strategy: string
- payment_product_name_prefix: string
- payment_product_name_suffix: string
- payment_help_image_url: string
- payment_help_text: string
- payment_cancel_rate_limit_enabled: boolean
- payment_cancel_rate_limit_max: number
- payment_cancel_rate_limit_window: number
- payment_cancel_rate_limit_unit: string
- payment_cancel_rate_limit_window_mode: string
+ payment_enabled: boolean;
+ payment_min_amount: number;
+ payment_max_amount: number;
+ payment_daily_limit: number;
+ payment_order_timeout_minutes: number;
+ payment_max_pending_orders: number;
+ payment_enabled_types: string[];
+ payment_balance_disabled: boolean;
+ payment_balance_recharge_multiplier: number;
+ payment_recharge_fee_rate: number;
+ payment_load_balance_strategy: string;
+ payment_product_name_prefix: string;
+ payment_product_name_suffix: string;
+ payment_help_image_url: string;
+ payment_help_text: string;
+ payment_cancel_rate_limit_enabled: boolean;
+ payment_cancel_rate_limit_max: number;
+ payment_cancel_rate_limit_window: number;
+ payment_cancel_rate_limit_unit: string;
+ payment_cancel_rate_limit_window_mode: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
// Balance & quota notification
- balance_low_notify_enabled: boolean
- balance_low_notify_threshold: number
- balance_low_notify_recharge_url: string
- account_quota_notify_enabled: boolean
- account_quota_notify_emails: NotifyEmailEntry[]
+ balance_low_notify_enabled: boolean;
+ balance_low_notify_threshold: number;
+ balance_low_notify_recharge_url: string;
+ account_quota_notify_enabled: boolean;
+ account_quota_notify_emails: NotifyEmailEntry[];
+
+ // Channel Monitor feature switch
+ channel_monitor_enabled: boolean;
+ channel_monitor_default_interval_seconds: number;
+
+ // Available Channels feature switch
+ available_channels_enabled: boolean;
+
+ // Affiliate (邀请返利) feature switch
+ affiliate_enabled: boolean;
+
+ // OpenAI fast/flex policy
+ openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
export interface UpdateSettingsRequest {
- registration_enabled?: boolean
- email_verify_enabled?: boolean
- registration_email_suffix_whitelist?: string[]
- promo_code_enabled?: boolean
- password_reset_enabled?: boolean
- frontend_url?: string
- invitation_code_enabled?: boolean
- totp_enabled?: boolean // TOTP 双因素认证
- default_balance?: number
- default_concurrency?: number
- default_subscriptions?: DefaultSubscriptionSetting[]
- site_name?: string
- site_logo?: string
- site_subtitle?: string
- api_base_url?: string
- contact_info?: string
- doc_url?: string
- home_content?: string
- hide_ccs_import_button?: boolean
- table_default_page_size?: number
- table_page_size_options?: number[]
- backend_mode_enabled?: boolean
- custom_menu_items?: CustomMenuItem[]
- custom_endpoints?: CustomEndpoint[]
- smtp_host?: string
- smtp_port?: number
- smtp_username?: string
- smtp_password?: string
- smtp_from_email?: string
- smtp_from_name?: string
- smtp_use_tls?: boolean
- turnstile_enabled?: boolean
- turnstile_site_key?: string
- turnstile_secret_key?: string
- linuxdo_connect_enabled?: boolean
- linuxdo_connect_client_id?: string
- linuxdo_connect_client_secret?: string
- linuxdo_connect_redirect_url?: string
- oidc_connect_enabled?: boolean
- oidc_connect_provider_name?: string
- oidc_connect_client_id?: string
- oidc_connect_client_secret?: string
- oidc_connect_issuer_url?: string
- oidc_connect_discovery_url?: string
- oidc_connect_authorize_url?: string
- oidc_connect_token_url?: string
- oidc_connect_userinfo_url?: string
- oidc_connect_jwks_url?: string
- oidc_connect_scopes?: string
- oidc_connect_redirect_url?: string
- oidc_connect_frontend_redirect_url?: string
- oidc_connect_token_auth_method?: string
- oidc_connect_use_pkce?: boolean
- oidc_connect_validate_id_token?: boolean
- oidc_connect_allowed_signing_algs?: string
- oidc_connect_clock_skew_seconds?: number
- oidc_connect_require_email_verified?: boolean
- oidc_connect_userinfo_email_path?: string
- oidc_connect_userinfo_id_path?: string
- oidc_connect_userinfo_username_path?: string
- enable_model_fallback?: boolean
- fallback_model_anthropic?: string
- fallback_model_openai?: string
- fallback_model_gemini?: string
- fallback_model_antigravity?: string
- enable_identity_patch?: boolean
- identity_patch_prompt?: string
- ops_monitoring_enabled?: boolean
- ops_realtime_monitoring_enabled?: boolean
- ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds?: number
- min_claude_code_version?: string
- max_claude_code_version?: string
- allow_ungrouped_key_scheduling?: boolean
- enable_fingerprint_unification?: boolean
- enable_metadata_passthrough?: boolean
- enable_cch_signing?: boolean
+ registration_enabled?: boolean;
+ email_verify_enabled?: boolean;
+ registration_email_suffix_whitelist?: string[];
+ promo_code_enabled?: boolean;
+ password_reset_enabled?: boolean;
+ frontend_url?: string;
+ invitation_code_enabled?: boolean;
+ totp_enabled?: boolean; // TOTP 双因素认证
+ default_balance?: number;
+ affiliate_rebate_rate?: number;
+ affiliate_rebate_freeze_hours?: number;
+ affiliate_rebate_duration_days?: number;
+ affiliate_rebate_per_invitee_cap?: number;
+ default_concurrency?: number;
+ default_user_rpm_limit?: number;
+ default_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
+ site_name?: string;
+ site_logo?: string;
+ site_subtitle?: string;
+ api_base_url?: string;
+ contact_info?: string;
+ doc_url?: string;
+ home_content?: string;
+ hide_ccs_import_button?: boolean;
+ table_default_page_size?: number;
+ table_page_size_options?: number[];
+ backend_mode_enabled?: boolean;
+ custom_menu_items?: CustomMenuItem[];
+ custom_endpoints?: CustomEndpoint[];
+ smtp_host?: string;
+ smtp_port?: number;
+ smtp_username?: string;
+ smtp_password?: string;
+ smtp_from_email?: string;
+ smtp_from_name?: string;
+ smtp_use_tls?: boolean;
+ turnstile_enabled?: boolean;
+ turnstile_site_key?: string;
+ turnstile_secret_key?: string;
+ linuxdo_connect_enabled?: boolean;
+ linuxdo_connect_client_id?: string;
+ linuxdo_connect_client_secret?: string;
+ linuxdo_connect_redirect_url?: string;
+ wechat_connect_enabled?: boolean;
+ wechat_connect_app_id?: string;
+ wechat_connect_app_secret?: string;
+ wechat_connect_open_app_id?: string;
+ wechat_connect_open_app_secret?: string;
+ wechat_connect_mp_app_id?: string;
+ wechat_connect_mp_app_secret?: string;
+ wechat_connect_mobile_app_id?: string;
+ wechat_connect_mobile_app_secret?: string;
+ wechat_connect_open_enabled?: boolean;
+ wechat_connect_mp_enabled?: boolean;
+ wechat_connect_mobile_enabled?: boolean;
+ wechat_connect_mode?: string;
+ wechat_connect_scopes?: string;
+ wechat_connect_redirect_url?: string;
+ wechat_connect_frontend_redirect_url?: string;
+ oidc_connect_enabled?: boolean;
+ oidc_connect_provider_name?: string;
+ oidc_connect_client_id?: string;
+ oidc_connect_client_secret?: string;
+ oidc_connect_issuer_url?: string;
+ oidc_connect_discovery_url?: string;
+ oidc_connect_authorize_url?: string;
+ oidc_connect_token_url?: string;
+ oidc_connect_userinfo_url?: string;
+ oidc_connect_jwks_url?: string;
+ oidc_connect_scopes?: string;
+ oidc_connect_redirect_url?: string;
+ oidc_connect_frontend_redirect_url?: string;
+ oidc_connect_token_auth_method?: string;
+ oidc_connect_use_pkce?: boolean;
+ oidc_connect_validate_id_token?: boolean;
+ oidc_connect_allowed_signing_algs?: string;
+ oidc_connect_clock_skew_seconds?: number;
+ oidc_connect_require_email_verified?: boolean;
+ oidc_connect_userinfo_email_path?: string;
+ oidc_connect_userinfo_id_path?: string;
+ oidc_connect_userinfo_username_path?: string;
+ enable_model_fallback?: boolean;
+ fallback_model_anthropic?: string;
+ fallback_model_openai?: string;
+ fallback_model_gemini?: string;
+ fallback_model_antigravity?: string;
+ enable_identity_patch?: boolean;
+ identity_patch_prompt?: string;
+ ops_monitoring_enabled?: boolean;
+ ops_realtime_monitoring_enabled?: boolean;
+ ops_query_mode_default?: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds?: number;
+ min_claude_code_version?: string;
+ max_claude_code_version?: string;
+ allow_ungrouped_key_scheduling?: boolean;
+ enable_fingerprint_unification?: boolean;
+ enable_metadata_passthrough?: boolean;
+ enable_cch_signing?: boolean;
+ enable_anthropic_cache_ttl_1h_injection?: boolean;
// Payment configuration
- payment_enabled?: boolean
- payment_min_amount?: number
- payment_max_amount?: number
- payment_daily_limit?: number
- payment_order_timeout_minutes?: number
- payment_max_pending_orders?: number
- payment_enabled_types?: string[]
- payment_balance_disabled?: boolean
- payment_balance_recharge_multiplier?: number
- payment_recharge_fee_rate?: number
- payment_load_balance_strategy?: string
- payment_product_name_prefix?: string
- payment_product_name_suffix?: string
- payment_help_image_url?: string
- payment_help_text?: string
- payment_cancel_rate_limit_enabled?: boolean
- payment_cancel_rate_limit_max?: number
- payment_cancel_rate_limit_window?: number
- payment_cancel_rate_limit_unit?: string
- payment_cancel_rate_limit_window_mode?: string
+ payment_enabled?: boolean;
+ payment_min_amount?: number;
+ payment_max_amount?: number;
+ payment_daily_limit?: number;
+ payment_order_timeout_minutes?: number;
+ payment_max_pending_orders?: number;
+ payment_enabled_types?: string[];
+ payment_balance_disabled?: boolean;
+ payment_balance_recharge_multiplier?: number;
+ payment_recharge_fee_rate?: number;
+ payment_load_balance_strategy?: string;
+ payment_product_name_prefix?: string;
+ payment_product_name_suffix?: string;
+ payment_help_image_url?: string;
+ payment_help_text?: string;
+ payment_cancel_rate_limit_enabled?: boolean;
+ payment_cancel_rate_limit_max?: number;
+ payment_cancel_rate_limit_window?: number;
+ payment_cancel_rate_limit_unit?: string;
+ payment_cancel_rate_limit_window_mode?: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
// Balance & quota notification
- balance_low_notify_enabled?: boolean
- balance_low_notify_threshold?: number
- balance_low_notify_recharge_url?: string
- account_quota_notify_enabled?: boolean
- account_quota_notify_emails?: NotifyEmailEntry[]
+ balance_low_notify_enabled?: boolean;
+ balance_low_notify_threshold?: number;
+ balance_low_notify_recharge_url?: string;
+ account_quota_notify_enabled?: boolean;
+ account_quota_notify_emails?: NotifyEmailEntry[];
+
+ // Channel Monitor feature switch
+ channel_monitor_enabled?: boolean;
+ channel_monitor_default_interval_seconds?: number;
+
+ // Available Channels feature switch
+ available_channels_enabled?: boolean;
+
+ // Affiliate (邀请返利) feature switch
+ affiliate_enabled?: boolean;
+
+ // OpenAI fast/flex policy
+ openai_fast_policy_settings?: OpenAIFastPolicySettings;
}
/**
@@ -258,8 +663,8 @@ export interface UpdateSettingsRequest {
* @returns System settings
*/
export async function getSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings')
- return data
+ const { data } = await apiClient.get("/admin/settings");
+ return data;
}
/**
@@ -267,20 +672,25 @@ export async function getSettings(): Promise {
* @param settings - Partial settings to update
* @returns Updated settings
*/
-export async function updateSettings(settings: UpdateSettingsRequest): Promise {
- const { data } = await apiClient.put('/admin/settings', settings)
- return data
+export async function updateSettings(
+ settings: UpdateSettingsRequest,
+): Promise {
+ const { data } = await apiClient.put(
+ "/admin/settings",
+ settings,
+ );
+ return data;
}
/**
* Test SMTP connection request
*/
export interface TestSmtpRequest {
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_use_tls: boolean;
}
/**
@@ -288,23 +698,28 @@ export interface TestSmtpRequest {
* @param config - SMTP configuration to test
* @returns Test result message
*/
-export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> {
- const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config)
- return data
+export async function testSmtpConnection(
+ config: TestSmtpRequest,
+): Promise<{ message: string }> {
+ const { data } = await apiClient.post<{ message: string }>(
+ "/admin/settings/test-smtp",
+ config,
+ );
+ return data;
}
/**
* Send test email request
*/
export interface SendTestEmailRequest {
- email: string
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ email: string;
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
}
/**
@@ -312,20 +727,22 @@ export interface SendTestEmailRequest {
* @param request - Email address and SMTP config
* @returns Test result message
*/
-export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> {
+export async function sendTestEmail(
+ request: SendTestEmailRequest,
+): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>(
- '/admin/settings/send-test-email',
- request
- )
- return data
+ "/admin/settings/send-test-email",
+ request,
+ );
+ return data;
}
/**
* Admin API Key status response
*/
export interface AdminApiKeyStatus {
- exists: boolean
- masked_key: string
+ exists: boolean;
+ masked_key: string;
}
/**
@@ -333,8 +750,10 @@ export interface AdminApiKeyStatus {
* @returns Status indicating if key exists and masked version
*/
export async function getAdminApiKey(): Promise {
- const { data } = await apiClient.get('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
/**
@@ -342,8 +761,10 @@ export async function getAdminApiKey(): Promise {
* @returns The new full API key (only shown once)
*/
export async function regenerateAdminApiKey(): Promise<{ key: string }> {
- const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate')
- return data
+ const { data } = await apiClient.post<{ key: string }>(
+ "/admin/settings/admin-api-key/regenerate",
+ );
+ return data;
}
/**
@@ -351,8 +772,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> {
* @returns Success message
*/
export async function deleteAdminApiKey(): Promise<{ message: string }> {
- const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.delete<{ message: string }>(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
// ==================== Overload Cooldown Settings ====================
@@ -361,23 +784,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> {
* Overload cooldown settings interface (529 handling)
*/
export interface OverloadCooldownSettings {
- enabled: boolean
- cooldown_minutes: number
+ enabled: boolean;
+ cooldown_minutes: number;
}
export async function getOverloadCooldownSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/overload-cooldown')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/overload-cooldown",
+ );
+ return data;
}
export async function updateOverloadCooldownSettings(
- settings: OverloadCooldownSettings
+ settings: OverloadCooldownSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/overload-cooldown',
- settings
- )
- return data
+ "/admin/settings/overload-cooldown",
+ settings,
+ );
+ return data;
}
// ==================== Stream Timeout Settings ====================
@@ -386,11 +811,11 @@ export async function updateOverloadCooldownSettings(
* Stream timeout settings interface
*/
export interface StreamTimeoutSettings {
- enabled: boolean
- action: 'temp_unsched' | 'error' | 'none'
- temp_unsched_minutes: number
- threshold_count: number
- threshold_window_minutes: number
+ enabled: boolean;
+ action: "temp_unsched" | "error" | "none";
+ temp_unsched_minutes: number;
+ threshold_count: number;
+ threshold_window_minutes: number;
}
/**
@@ -398,8 +823,10 @@ export interface StreamTimeoutSettings {
* @returns Stream timeout settings
*/
export async function getStreamTimeoutSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/stream-timeout')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/stream-timeout",
+ );
+ return data;
}
/**
@@ -408,13 +835,13 @@ export async function getStreamTimeoutSettings(): Promise
* @returns Updated settings
*/
export async function updateStreamTimeoutSettings(
- settings: StreamTimeoutSettings
+ settings: StreamTimeoutSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/stream-timeout',
- settings
- )
- return data
+ "/admin/settings/stream-timeout",
+ settings,
+ );
+ return data;
}
// ==================== Rectifier Settings ====================
@@ -423,11 +850,11 @@ export async function updateStreamTimeoutSettings(
* Rectifier settings interface
*/
export interface RectifierSettings {
- enabled: boolean
- thinking_signature_enabled: boolean
- thinking_budget_enabled: boolean
- apikey_signature_enabled: boolean
- apikey_signature_patterns: string[]
+ enabled: boolean;
+ thinking_signature_enabled: boolean;
+ thinking_budget_enabled: boolean;
+ apikey_signature_enabled: boolean;
+ apikey_signature_patterns: string[];
}
/**
@@ -435,8 +862,10 @@ export interface RectifierSettings {
* @returns Rectifier settings
*/
export async function getRectifierSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/rectifier')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/rectifier",
+ );
+ return data;
}
/**
@@ -445,13 +874,36 @@ export async function getRectifierSettings(): Promise {
* @returns Updated settings
*/
export async function updateRectifierSettings(
- settings: RectifierSettings
+ settings: RectifierSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/rectifier',
- settings
- )
- return data
+ "/admin/settings/rectifier",
+ settings,
+ );
+ return data;
+}
+
+// ==================== OpenAI Fast Policy Settings ====================
+
+/**
+ * OpenAI fast/flex policy rule interface.
+ * Matches backend dto.OpenAIFastPolicyRule.
+ */
+export interface OpenAIFastPolicyRule {
+ service_tier: "all" | "priority" | "flex";
+ action: "pass" | "filter" | "block";
+ scope: "all" | "oauth" | "apikey" | "bedrock";
+ error_message?: string;
+ model_whitelist?: string[];
+ fallback_action?: "pass" | "filter" | "block";
+ fallback_error_message?: string;
+}
+
+/**
+ * OpenAI fast/flex policy settings interface.
+ */
+export interface OpenAIFastPolicySettings {
+ rules: OpenAIFastPolicyRule[];
}
// ==================== Beta Policy Settings ====================
@@ -460,20 +912,20 @@ export async function updateRectifierSettings(
* Beta policy rule interface
*/
export interface BetaPolicyRule {
- beta_token: string
- action: 'pass' | 'filter' | 'block'
- scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
- error_message?: string
- model_whitelist?: string[]
- fallback_action?: 'pass' | 'filter' | 'block'
- fallback_error_message?: string
+ beta_token: string;
+ action: "pass" | "filter" | "block";
+ scope: "all" | "oauth" | "apikey" | "bedrock";
+ error_message?: string;
+ model_whitelist?: string[];
+ fallback_action?: "pass" | "filter" | "block";
+ fallback_error_message?: string;
}
/**
* Beta policy settings interface
*/
export interface BetaPolicySettings {
- rules: BetaPolicyRule[]
+ rules: BetaPolicyRule[];
}
/**
@@ -481,8 +933,10 @@ export interface BetaPolicySettings {
* @returns Beta policy settings
*/
export async function getBetaPolicySettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/beta-policy')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/beta-policy",
+ );
+ return data;
}
/**
@@ -491,70 +945,73 @@ export async function getBetaPolicySettings(): Promise {
* @returns Updated settings
*/
export async function updateBetaPolicySettings(
- settings: BetaPolicySettings
+ settings: BetaPolicySettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/beta-policy',
- settings
- )
- return data
+ "/admin/settings/beta-policy",
+ settings,
+ );
+ return data;
}
// --- Web Search Emulation Config ---
export interface WebSearchProviderConfig {
- type: 'brave' | 'tavily'
- api_key: string
- api_key_configured: boolean
- quota_limit: number | null
- subscribed_at: number | null
- quota_used?: number
- proxy_id: number | null
- expires_at: number | null
+ type: "brave" | "tavily";
+ api_key: string;
+ api_key_configured: boolean;
+ quota_limit: number | null;
+ subscribed_at: number | null;
+ quota_used?: number;
+ proxy_id: number | null;
+ expires_at: number | null;
}
export interface WebSearchEmulationConfig {
- enabled: boolean
- providers: WebSearchProviderConfig[]
+ enabled: boolean;
+ providers: WebSearchProviderConfig[];
}
export interface WebSearchTestResult {
- provider: string
- results: { url: string; title: string; snippet: string; page_age?: string }[]
- query: string
+ provider: string;
+ results: { url: string; title: string; snippet: string; page_age?: string }[];
+ query: string;
}
export async function getWebSearchEmulationConfig(): Promise {
const { data } = await apiClient.get(
- '/admin/settings/web-search-emulation'
- )
- return data
+ "/admin/settings/web-search-emulation",
+ );
+ return data;
}
export async function updateWebSearchEmulationConfig(
- config: WebSearchEmulationConfig
+ config: WebSearchEmulationConfig,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/web-search-emulation',
- config
- )
- return data
+ "/admin/settings/web-search-emulation",
+ config,
+ );
+ return data;
}
export async function testWebSearchEmulation(
- query: string
+ query: string,
): Promise {
const { data } = await apiClient.post(
- '/admin/settings/web-search-emulation/test',
- { query }
- )
- return data
+ "/admin/settings/web-search-emulation/test",
+ { query },
+ );
+ return data;
}
-export async function resetWebSearchUsage(
- payload: { provider_type: string }
-): Promise {
- await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload)
+export async function resetWebSearchUsage(payload: {
+ provider_type: string;
+}): Promise {
+ await apiClient.post(
+ "/admin/settings/web-search-emulation/reset-usage",
+ payload,
+ );
}
export const settingsAPI = {
@@ -576,7 +1033,7 @@ export const settingsAPI = {
getWebSearchEmulationConfig,
updateWebSearchEmulationConfig,
testWebSearchEmulation,
- resetWebSearchUsage
-}
+ resetWebSearchUsage,
+};
-export default settingsAPI
+export default settingsAPI;
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index 39cb1dfa..3c75a6c4 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -6,6 +6,44 @@
import { apiClient } from '../client'
import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types'
+export interface AdminBindAuthIdentityChannelRequest {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata?: Record | null
+}
+
+export interface AdminBindAuthIdentityRequest {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string | null
+ metadata?: Record | null
+ channel?: AdminBindAuthIdentityChannelRequest
+}
+
+export interface AdminBoundAuthIdentityChannel {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+}
+
+export interface AdminBoundAuthIdentity {
+ user_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ verified_at?: string | null
+ issuer?: string | null
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ channel?: AdminBoundAuthIdentityChannel | null
+}
+
/**
* List all users with pagination
* @param page - Page number (default: 1)
@@ -248,6 +286,17 @@ export async function replaceGroup(
return data
}
+export async function bindUserAuthIdentity(
+ userId: number,
+ input: AdminBindAuthIdentityRequest
+): Promise {
+ const { data } = await apiClient.post(
+ `/admin/users/${userId}/auth-identities`,
+ input
+ )
+ return data
+}
+
export const usersAPI = {
list,
getById,
@@ -260,7 +309,8 @@ export const usersAPI = {
getUserApiKeys,
getUserUsageStats,
getUserBalanceHistory,
- replaceGroup
+ replaceGroup,
+ bindUserAuthIdentity
}
export default usersAPI
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index 837c4f4c..bb990fc4 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -186,6 +186,108 @@ export interface RefreshTokenResponse {
token_type: string
}
+export interface OAuthTokenResponse {
+ access_token: string
+ refresh_token?: string
+ expires_in?: number
+ token_type?: string
+}
+
+export interface PendingOAuthBindLoginResponse extends Partial {
+ auth_result?: string
+ redirect?: string
+ error?: string
+ requires_2fa?: boolean
+ temp_token?: string
+ user_email_masked?: string
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}
+
+export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse
+
+export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {
+ auth_result?: string
+}
+
+export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse {
+ auth_result?: string
+ provider?: string
+ redirect?: string
+}
+
+export type OAuthCompletionKind = 'login' | 'bind'
+
+export interface OAuthAdoptionDecision {
+ adoptDisplayName?: boolean
+ adoptAvatar?: boolean
+}
+
+function serializeOAuthAdoptionDecision(
+ decision?: OAuthAdoptionDecision
+): Record {
+ const payload: Record = {}
+
+ if (typeof decision?.adoptDisplayName === 'boolean') {
+ payload.adopt_display_name = decision.adoptDisplayName
+ }
+ if (typeof decision?.adoptAvatar === 'boolean') {
+ payload.adopt_avatar = decision.adoptAvatar
+ }
+
+ return payload
+}
+
+export function isOAuthLoginCompletion(
+ completion: Partial
+): completion is OAuthTokenResponse {
+ return typeof completion.access_token === 'string' && completion.access_token.trim().length > 0
+}
+
+export function getOAuthCompletionKind(
+ completion: Partial
+): OAuthCompletionKind {
+ return isOAuthLoginCompletion(completion) ? 'login' : 'bind'
+}
+
+export function getPendingOAuthBindLoginKind(
+ completion: PendingOAuthBindLoginResponse
+): OAuthCompletionKind {
+ return getOAuthCompletionKind(completion)
+}
+
+export function isPendingOAuthCreateAccountRequired(
+ completion: Pick
+): boolean {
+ return completion.error === 'invitation_required'
+}
+
+export function hasPendingOAuthSuggestedProfile(
+ completion: Pick<
+ PendingOAuthBindLoginResponse,
+ 'suggested_display_name' | 'suggested_avatar_url'
+ >
+): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+export function persistOAuthTokenContext(tokens: Partial): void {
+ if (tokens.refresh_token) {
+ setRefreshToken(tokens.refresh_token)
+ }
+ if (tokens.expires_in) {
+ setTokenExpiresAt(tokens.expires_in)
+ }
+}
+
+export async function prepareOAuthBindAccessTokenCookie(): Promise {
+ if (!getAuthToken()) {
+ return
+ }
+ await apiClient.post('/auth/oauth/bind-token')
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -234,6 +336,116 @@ export async function getPublicSettings(): Promise {
return data
}
+export type WeChatOAuthMode = 'open' | 'mp'
+export type WeChatOAuthUnavailableReason =
+ | 'not_configured'
+ | 'capability_unknown'
+ | 'external_browser_required'
+ | 'wechat_browser_required'
+ | 'native_app_required'
+
+export interface ResolvedWeChatOAuthStart {
+ mode: WeChatOAuthMode | null
+ openEnabled: boolean
+ mpEnabled: boolean
+ mobileEnabled: boolean
+ isWeChatBrowser: boolean
+ unavailableReason: WeChatOAuthUnavailableReason | null
+}
+
+export type WeChatOAuthPublicSettings = {
+ wechat_oauth_enabled?: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+ wechat_oauth_mobile_enabled?: boolean
+}
+
+export function isWeChatWebOAuthEnabled(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): boolean {
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const hasExplicitCapabilities =
+ typeof settings?.wechat_oauth_open_enabled === 'boolean' ||
+ typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+
+ if (!hasExplicitCapabilities) {
+ return legacyEnabled
+ }
+
+ return settings?.wechat_oauth_open_enabled === true || settings?.wechat_oauth_mp_enabled === true
+}
+
+export function hasExplicitWeChatOAuthCapabilities(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): settings is WeChatOAuthPublicSettings & {
+ wechat_oauth_open_enabled: boolean
+ wechat_oauth_mp_enabled: boolean
+} {
+ return typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ && typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+}
+
+export function resolveWeChatOAuthStart(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ ? settings.wechat_oauth_open_enabled
+ : legacyEnabled
+ const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+ ? settings.wechat_oauth_mp_enabled
+ : legacyEnabled
+ const mobileEnabled = typeof settings?.wechat_oauth_mobile_enabled === 'boolean'
+ ? settings.wechat_oauth_mobile_enabled
+ : false
+
+ if (isWeChatBrowser) {
+ if (mpEnabled) {
+ return { mode: 'mp', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (openEnabled) {
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+ }
+
+ if (openEnabled) {
+ return { mode: 'open', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (mpEnabled) {
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+}
+
+export function resolveWeChatOAuthStartStrict(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string,
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+
+ if (!hasExplicitWeChatOAuthCapabilities(settings)) {
+ return {
+ mode: null,
+ openEnabled: false,
+ mpEnabled: false,
+ mobileEnabled: false,
+ isWeChatBrowser,
+ unavailableReason: 'capability_unknown',
+ }
+ }
+
+ return resolveWeChatOAuthStart(settings, normalizedUserAgent)
+}
+
/**
* Send verification code to email
* @param request - Email and optional Turnstile token
@@ -246,6 +458,16 @@ export async function sendVerifyCode(
return data
}
+export async function sendPendingOAuthVerifyCode(
+ request: SendVerifyCodeRequest
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/send-verify-code',
+ request
+ )
+ return data
+}
+
/**
* Validate promo code response
*/
@@ -337,48 +559,96 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
- const { data } = await apiClient.post<{
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
- }>('/auth/oauth/linuxdo/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
- invitation_code: invitationCode
- })
- return data
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingLinuxDoOAuthAccount(invitationCode, decision, affiliateCode)
}
/**
* Complete OIDC OAuth registration by supplying an invitation code
- * @param pendingOAuthToken - Short-lived JWT from the OAuth callback
* @param invitationCode - Invitation code entered by the user
* @returns Token pair on success
*/
export async function completeOIDCOAuthRegistration(
- pendingOAuthToken: string,
- invitationCode: string
-): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> {
- const { data } = await apiClient.post<{
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
- }>('/auth/oauth/oidc/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
- invitation_code: invitationCode
- })
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOIDCOAuthAccount(invitationCode, decision, affiliateCode)
+}
+
+export async function completeWeChatOAuthRegistration(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingWeChatOAuthAccount(invitationCode, decision, affiliateCode)
+}
+
+async function createPendingOAuthAccount(
+ provider: 'linuxdo' | 'oidc' | 'wechat',
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ const normalizedAffiliateCode = affiliateCode?.trim()
+ const { data } = await apiClient.post(
+ `/auth/oauth/${provider}/complete-registration`,
+ {
+ invitation_code: invitationCode,
+ ...(normalizedAffiliateCode ? { aff_code: normalizedAffiliateCode } : {}),
+ ...serializeOAuthAdoptionDecision(decision)
+ }
+ )
return data
}
+export async function createPendingLinuxDoOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('linuxdo', invitationCode, decision, affiliateCode)
+}
+
+export async function createPendingOIDCOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('oidc', invitationCode, decision, affiliateCode)
+}
+
+export async function createPendingWeChatOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode)
+}
+
+export async function completePendingOAuthBindLogin(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/exchange',
+ serializeOAuthAdoptionDecision(decision)
+ )
+ return data
+}
+
+export async function exchangePendingOAuthCompletion(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return completePendingOAuthBindLogin(decision)
+}
+
export const authAPI = {
login,
login2FA,
@@ -396,14 +666,24 @@ export const authAPI = {
clearAuthToken,
getPublicSettings,
sendVerifyCode,
+ sendPendingOAuthVerifyCode,
validatePromoCode,
validateInvitationCode,
forgotPassword,
resetPassword,
refreshToken,
revokeAllSessions,
+ getPendingOAuthBindLoginKind,
+ isPendingOAuthCreateAccountRequired,
+ hasPendingOAuthSuggestedProfile,
+ completePendingOAuthBindLogin,
+ createPendingLinuxDoOAuthAccount,
+ createPendingOIDCOAuthAccount,
+ createPendingWeChatOAuthAccount,
+ exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
- completeOIDCOAuthRegistration
+ completeOIDCOAuthRegistration,
+ completeWeChatOAuthRegistration
}
export default authAPI
diff --git a/frontend/src/api/channelMonitor.ts b/frontend/src/api/channelMonitor.ts
new file mode 100644
index 00000000..38dd0c99
--- /dev/null
+++ b/frontend/src/api/channelMonitor.ts
@@ -0,0 +1,83 @@
+/**
+ * User-facing Channel Monitor API endpoints
+ * Read-only views for end users to inspect channel availability/status.
+ */
+
+import { apiClient } from './client'
+import type { Provider, MonitorStatus } from './admin/channelMonitor'
+
+export type { Provider, MonitorStatus } from './admin/channelMonitor'
+
+export interface UserMonitorExtraModel {
+ model: string
+ status: MonitorStatus
+ latency_ms: number | null
+}
+
+export interface MonitorTimelinePoint {
+ status: MonitorStatus
+ latency_ms: number | null
+ ping_latency_ms: number | null
+ checked_at: string
+}
+
+export interface UserMonitorView {
+ id: number
+ name: string
+ provider: Provider
+ group_name: string
+ primary_model: string
+ primary_status: MonitorStatus
+ primary_latency_ms: number | null
+ primary_ping_latency_ms: number | null
+ availability_7d: number
+ extra_models: UserMonitorExtraModel[]
+ timeline: MonitorTimelinePoint[]
+}
+
+export interface UserMonitorListResponse {
+ items: UserMonitorView[]
+}
+
+export interface UserMonitorModelDetail {
+ model: string
+ latest_status: MonitorStatus
+ latest_latency_ms: number | null
+ availability_7d: number
+ availability_15d: number
+ availability_30d: number
+ avg_latency_7d_ms: number | null
+}
+
+export interface UserMonitorDetail {
+ id: number
+ name: string
+ provider: Provider
+ group_name: string
+ models: UserMonitorModelDetail[]
+}
+
+/**
+ * List all monitor views available to the current user.
+ */
+export async function list(options?: { signal?: AbortSignal }): Promise {
+ const { data } = await apiClient.get('/channel-monitors', {
+ signal: options?.signal,
+ })
+ return data
+}
+
+/**
+ * Get detailed status (multi-window availability + latency) for a single monitor.
+ */
+export async function status(id: number): Promise {
+ const { data } = await apiClient.get(`/channel-monitors/${id}/status`)
+ return data
+}
+
+export const channelMonitorUserAPI = {
+ list,
+ status,
+}
+
+export default channelMonitorUserAPI
diff --git a/frontend/src/api/channels.ts b/frontend/src/api/channels.ts
new file mode 100644
index 00000000..8962af2c
--- /dev/null
+++ b/frontend/src/api/channels.ts
@@ -0,0 +1,76 @@
+/**
+ * User Channels API endpoints (non-admin)
+ * 用户侧「可用渠道」聚合查询:渠道 + 用户可访问的分组 + 支持模型(含定价)。
+ */
+
+import { apiClient } from './client'
+import type { BillingMode } from '@/constants/channel'
+
+export interface UserAvailableGroup {
+ id: number
+ name: string
+ platform: string
+ /** 'standard' | 'subscription' — 订阅分组视觉加深,和 API 密钥页保持一致。 */
+ subscription_type: string
+ /** 分组默认倍率。用户专属倍率(若有)通过 /groups/rates 获取后在前端 join。 */
+ rate_multiplier: number
+ /** true = 专属分组(小范围授权);false = 公开分组。 */
+ is_exclusive: boolean
+}
+
+export interface UserPricingInterval {
+ min_tokens: number
+ max_tokens: number | null
+ tier_label?: string
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ per_request_price: number | null
+}
+
+export interface UserSupportedModelPricing {
+ billing_mode: BillingMode
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ image_output_price: number | null
+ per_request_price: number | null
+ intervals: UserPricingInterval[]
+}
+
+export interface UserSupportedModel {
+ name: string
+ platform: string
+ pricing: UserSupportedModelPricing | null
+}
+
+/**
+ * 渠道下单个平台的子视图:用户可访问的分组 + 该平台支持的模型。
+ * 后端把一个渠道按平台聚合成 sections,前端可以把渠道名作为 row-group
+ * 一次渲染,后面按 sections 顺序用 rowspan 铺开。
+ */
+export interface UserChannelPlatformSection {
+ platform: string
+ groups: UserAvailableGroup[]
+ supported_models: UserSupportedModel[]
+}
+
+export interface UserAvailableChannel {
+ name: string
+ description: string
+ platforms: UserChannelPlatformSection[]
+}
+
+/** 列出当前用户可见的「可用渠道」(与 /groups/available 保持一致,返回平数组)。 */
+export async function getAvailable(options?: { signal?: AbortSignal }): Promise {
+ const { data } = await apiClient.get('/channels/available', {
+ signal: options?.signal
+ })
+ return data
+}
+
+export const userChannelsAPI = { getAvailable }
+
+export default userChannelsAPI
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 8a586902..54ea4520 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -13,6 +13,7 @@ const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || '/api/v1'
export const apiClient: AxiosInstance = axios.create({
baseURL: API_BASE_URL,
+ withCredentials: true,
timeout: 30000,
headers: {
'Content-Type': 'application/json'
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
index 6b3ef174..6702468d 100644
--- a/frontend/src/api/index.ts
+++ b/frontend/src/api/index.ts
@@ -16,8 +16,10 @@ export { userAPI } from './user'
export { redeemAPI, type RedeemHistoryItem } from './redeem'
export { paymentAPI } from './payment'
export { userGroupsAPI } from './groups'
+export { userChannelsAPI } from './channels'
export { totpAPI } from './totp'
export { default as announcementsAPI } from './announcements'
+export { channelMonitorUserAPI } from './channelMonitor'
// Admin APIs
export { adminAPI } from './admin'
diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts
index 5cedb107..92b0ec90 100644
--- a/frontend/src/api/payment.ts
+++ b/frontend/src/api/payment.ts
@@ -67,11 +67,16 @@ export const paymentAPI = {
return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo })
},
- /** Verify order payment status without auth (public endpoint for result page) */
+ /** Legacy-compatible public order lookup by out_trade_no */
verifyOrderPublic(outTradeNo: string) {
return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo })
},
+ /** Resolve an order from a signed resume token without auth */
+ resolveOrderPublicByResumeToken(resumeToken: string) {
+ return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken })
+ },
+
/** Request a refund for a completed order */
requestRefund(id: number, data: { reason: string }) {
return apiClient.post(`/payment/orders/${id}/refund-request`, data)
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index cd648270..da7a91eb 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -4,7 +4,19 @@
*/
import { apiClient } from './client'
-import type { User, ChangePasswordRequest, NotifyEmailEntry } from '@/types'
+import {
+ resolveWeChatOAuthStartStrict,
+ prepareOAuthBindAccessTokenCookie,
+ type WeChatOAuthPublicSettings,
+} from './auth'
+import type {
+ User,
+ ChangePasswordRequest,
+ NotifyEmailEntry,
+ UserAuthProvider,
+ UserAffiliateDetail,
+ AffiliateTransferResponse
+} from '@/types'
/**
* Get current user profile
@@ -22,6 +34,7 @@ export async function getProfile(): Promise {
*/
export async function updateProfile(profile: {
username?: string
+ avatar_url?: string | null
balance_notify_enabled?: boolean
balance_notify_threshold?: number | null
balance_notify_extra_emails?: NotifyEmailEntry[]
@@ -83,6 +96,95 @@ export async function toggleNotifyEmail(email: string, disabled: boolean): Promi
return data
}
+export async function sendEmailBindingCode(email: string): Promise {
+ await apiClient.post('/user/account-bindings/email/send-code', { email })
+}
+
+export async function bindEmailIdentity(payload: {
+ email: string
+ verify_code: string
+ password: string
+}): Promise {
+ const { data } = await apiClient.post('/user/account-bindings/email', payload)
+ return data
+}
+
+export async function unbindAuthIdentity(provider: BindableOAuthProvider): Promise {
+ const { data } = await apiClient.delete(`/user/account-bindings/${provider}`)
+ return data
+}
+
+export type BindableOAuthProvider = Exclude
+
+interface BuildOAuthBindingStartURLOptions {
+ redirectTo?: string
+ wechatOAuthSettings?: WeChatOAuthPublicSettings | null
+}
+
+export function resolveWeChatOAuthMode(): 'open' | 'mp' {
+ if (typeof navigator === 'undefined') {
+ return 'open'
+ }
+ return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open'
+}
+
+function resolveWeChatOAuthBindingMode(
+ settings?: WeChatOAuthPublicSettings | null
+): 'open' | 'mp' | null {
+ if (settings) {
+ return resolveWeChatOAuthStartStrict(settings).mode
+ }
+ return resolveWeChatOAuthMode()
+}
+
+export function buildOAuthBindingStartURL(
+ provider: BindableOAuthProvider,
+ options: BuildOAuthBindingStartURLOptions = {}
+): string | null {
+ const redirectTo = options.redirectTo?.trim() || '/profile'
+ const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1'
+ const normalized = apiBase.replace(/\/$/, '')
+ const params = new URLSearchParams({
+ redirect: redirectTo,
+ intent: 'bind_current_user'
+ })
+
+ if (provider === 'wechat') {
+ const mode = resolveWeChatOAuthBindingMode(options.wechatOAuthSettings)
+ if (!mode) {
+ return null
+ }
+ params.set('mode', mode)
+ }
+
+ return `${normalized}/auth/oauth/${provider}/bind/start?${params.toString()}`
+}
+
+export async function startOAuthBinding(
+ provider: BindableOAuthProvider,
+ options: BuildOAuthBindingStartURLOptions = {}
+): Promise {
+ if (typeof window === 'undefined') {
+ return
+ }
+ const startURL = buildOAuthBindingStartURL(provider, options)
+ if (!startURL) {
+ return
+ }
+ await prepareOAuthBindAccessTokenCookie()
+ window.location.href = startURL
+}
+
+export async function getAffiliateDetail(): Promise {
+ const { data } = await apiClient.get('/user/aff')
+ return data
+}
+
+export async function transferAffiliateQuota(): Promise {
+ const { data } = await apiClient.post('/user/aff/transfer')
+ return data
+}
+
export const userAPI = {
getProfile,
updateProfile,
@@ -90,7 +192,14 @@ export const userAPI = {
sendNotifyEmailCode,
verifyNotifyEmail,
removeNotifyEmail,
- toggleNotifyEmail
+ toggleNotifyEmail,
+ sendEmailBindingCode,
+ bindEmailIdentity,
+ unbindAuthIdentity,
+ buildOAuthBindingStartURL,
+ startOAuthBinding,
+ getAffiliateDetail,
+ transferAffiliateQuota
}
export default userAPI
diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue
index fc2f7d0c..dd38a49f 100644
--- a/frontend/src/components/account/AccountStatusIndicator.vue
+++ b/frontend/src/components/account/AccountStatusIndicator.vue
@@ -284,6 +284,16 @@ const hasError = computed(() => {
return props.account.status === 'error'
})
+const isQuotaExceeded = computed(() => {
+ const exceeded = (used?: number | null, limit?: number | null) =>
+ typeof limit === 'number' && limit > 0 && typeof used === 'number' && used >= limit
+ return (
+ exceeded(props.account.quota_used, props.account.quota_limit) ||
+ exceeded(props.account.quota_daily_used, props.account.quota_daily_limit) ||
+ exceeded(props.account.quota_weekly_used, props.account.quota_weekly_limit)
+ )
+})
+
// Computed: countdown text for rate limit (429)
const rateLimitCountdown = computed(() => {
return formatCountdown(props.account.rate_limit_reset_at)
@@ -307,19 +317,16 @@ const statusClass = computed(() => {
if (isTempUnschedulable.value) {
return 'badge-warning'
}
+ if (props.account.status !== 'active') {
+ return props.account.status === 'error' ? 'badge-danger' : 'badge-gray'
+ }
+ if (isQuotaExceeded.value) {
+ return 'badge-warning'
+ }
if (!props.account.schedulable) {
return 'badge-gray'
}
- switch (props.account.status) {
- case 'active':
- return 'badge-success'
- case 'inactive':
- return 'badge-gray'
- case 'error':
- return 'badge-danger'
- default:
- return 'badge-gray'
- }
+ return 'badge-success'
})
// Computed: status text
@@ -330,6 +337,12 @@ const statusText = computed(() => {
if (isTempUnschedulable.value) {
return t('admin.accounts.status.tempUnschedulable')
}
+ if (props.account.status !== 'active') {
+ return t(`admin.accounts.status.${props.account.status}`)
+ }
+ if (isQuotaExceeded.value) {
+ return t('admin.accounts.status.quotaExceeded')
+ }
if (!props.account.schedulable) {
return t('admin.accounts.status.paused')
}
diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue
index 67409a7c..236f85f1 100644
--- a/frontend/src/components/account/AccountTestModal.vue
+++ b/frontend/src/components/account/AccountTestModal.vue
@@ -55,12 +55,23 @@
/>
-
+
+
+ {{ t('admin.accounts.openai.testMode') }}
+
+
+
+
+
@@ -122,25 +133,49 @@
- {{ t('admin.accounts.geminiImagePreview') }}
+ {{ t('admin.accounts.imagePreview') }}
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -152,8 +187,8 @@
{{
- supportsGeminiImageTest
- ? t('admin.accounts.geminiImageTestMode')
+ supportsImageTest
+ ? t('admin.accounts.imageTestMode')
: t('admin.accounts.testPrompt')
}}
@@ -250,6 +285,13 @@ const testPrompt = ref('')
const loadingModels = ref(false)
let abortController: AbortController | null = null
const generatedImages = ref
([])
+const testMode = ref<'default' | 'compact'>('default')
+const isOpenAIAccount = computed(() => props.account?.platform === 'openai')
+const openAITestModeOptions = computed(() => [
+ { value: 'default', label: t('admin.accounts.openai.testModeDefault') },
+ { value: 'compact', label: t('admin.accounts.openai.testModeCompact') }
+])
+const previewImageUrl = ref('')
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
const modelID = selectedModelId.value.toLowerCase()
@@ -258,6 +300,14 @@ const supportsGeminiImageTest = computed(() => {
return props.account?.platform === 'gemini' || (props.account?.platform === 'antigravity' && props.account?.type === 'apikey')
})
+const supportsOpenAIImageTest = computed(() => {
+ const modelID = selectedModelId.value.toLowerCase()
+ if (!modelID.startsWith('gpt-image-')) return false
+ return props.account?.platform === 'openai'
+})
+
+const supportsImageTest = computed(() => supportsGeminiImageTest.value || supportsOpenAIImageTest.value)
+
const sortTestModels = (models: ClaudeModel[]) => {
const priorityMap = new Map(prioritizedGeminiModels.map((id, index) => [id, index]))
@@ -275,6 +325,7 @@ watch(
async (newVal) => {
if (newVal && props.account) {
testPrompt.value = ''
+ testMode.value = 'default'
resetState()
await loadAvailableModels()
} else {
@@ -284,8 +335,8 @@ watch(
)
watch(selectedModelId, () => {
- if (supportsGeminiImageTest.value && !testPrompt.value.trim()) {
- testPrompt.value = t('admin.accounts.geminiImagePromptDefault')
+ if (supportsImageTest.value && !testPrompt.value.trim()) {
+ testPrompt.value = t('admin.accounts.imagePromptDefault')
}
})
@@ -325,6 +376,7 @@ const resetState = () => {
streamingContent.value = ''
errorMessage.value = ''
generatedImages.value = []
+ previewImageUrl.value = ''
}
const handleClose = () => {
@@ -376,9 +428,10 @@ const startTest = async () => {
'Content-Type': 'application/json'
},
body: JSON.stringify({
- model_id: selectedModelId.value,
- prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : ''
- }),
+ model_id: selectedModelId.value,
+ prompt: supportsImageTest.value ? testPrompt.value.trim() : '',
+ mode: isOpenAIAccount.value ? testMode.value : 'default'
+ }),
signal: abortController.signal
})
@@ -444,8 +497,8 @@ const handleEvent = (event: {
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
}
addLine(
- supportsGeminiImageTest.value
- ? t('admin.accounts.sendingGeminiImageRequest')
+ supportsImageTest.value
+ ? t('admin.accounts.sendingImageRequest')
: t('admin.accounts.sendingTestMessage'),
'text-gray-400'
)
@@ -466,7 +519,7 @@ const handleEvent = (event: {
url: event.image_url,
mimeType: event.mime_type
})
- addLine(t('admin.accounts.geminiImageReceived', { count: generatedImages.value.length }), 'text-purple-300')
+ addLine(t('admin.accounts.imageReceived', { count: generatedImages.value.length }), 'text-purple-300')
}
break
@@ -500,3 +553,14 @@ const copyOutput = () => {
copyToClipboard(text, t('admin.accounts.outputCopied'))
}
+
+
diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue
index 1c023fb3..2c04e673 100644
--- a/frontend/src/components/account/AccountUsageCell.vue
+++ b/frontend/src/components/account/AccountUsageCell.vue
@@ -332,6 +332,37 @@
+
+
+
+ {{ formatKeyRequests }} req
+
+
+ {{ formatKeyTokens }}
+
+
+ A ${{ formatKeyCost }}
+
+
+ U ${{ formatKeyUserCost }}
+
+
+
+
@@ -512,6 +543,10 @@ const shouldFetchUsage = computed(() => {
return false
})
+const showGeminiTodayStats = computed(() => {
+ return props.account.platform === 'gemini' && props.account.type === 'service_account'
+})
+
const geminiUsageAvailable = computed(() => {
return (
!!usageInfo.value?.gemini_shared_daily ||
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue
index 13c30cf9..05016a6d 100644
--- a/frontend/src/components/account/BulkEditAccountModal.vue
+++ b/frontend/src/components/account/BulkEditAccountModal.vue
@@ -17,7 +17,7 @@
d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
- {{ t('admin.accounts.bulkEdit.selectionInfo', { count: accountIds.length }) }}
+ {{ t('admin.accounts.bulkEdit.selectionInfo', { count: targetMode === 'filtered' ? targetPreviewCount : accountIds.length }) }}
@@ -27,7 +27,7 @@
- {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
+ {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: targetSelectedPlatforms.join(', ') }) }}
@@ -227,7 +227,7 @@
@@ -698,6 +698,87 @@
+
+
+
+
+ {{ t('admin.accounts.openai.codexCLIOnly') }}
+
+
+
+
+
+ {{ t('admin.accounts.openai.codexCLIOnlyDesc') }}
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.openai.wsMode') }}
+
+
+
+
+
+ {{ t('admin.accounts.openai.wsModeDesc') }}
+
+
+ {{ t(openAIAPIKeyWSModeConcurrencyHintKey) }}
+
+
+
+
+
@@ -933,6 +1014,13 @@ interface Props {
accountIds: number[]
selectedPlatforms: AccountPlatform[]
selectedTypes: AccountType[]
+ target?: {
+ mode: 'selected' | 'filtered'
+ filters?: Record
+ previewCount?: number
+ selectedPlatforms?: AccountPlatform[]
+ selectedTypes?: AccountType[]
+ }
proxies: ProxyConfig[]
groups: AdminGroup[]
}
@@ -947,40 +1035,53 @@ const { t } = useI18n()
const appStore = useAppStore()
// Platform awareness
-const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1)
+const targetMode = computed(() => props.target?.mode ?? 'selected')
+const targetPreviewCount = computed(() => props.target?.previewCount ?? props.accountIds.length)
+const targetSelectedPlatforms = computed(() => props.target?.selectedPlatforms ?? props.selectedPlatforms)
+const targetSelectedTypes = computed(() => props.target?.selectedTypes ?? props.selectedTypes)
+const isMixedPlatform = computed(() => targetSelectedPlatforms.value.length > 1)
const allOpenAIPassthroughCapable = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'openai' &&
- props.selectedTypes.length > 0 &&
- props.selectedTypes.every(t => t === 'oauth' || t === 'apikey')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'oauth' || t === 'apikey')
)
})
const allOpenAIOAuth = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'openai' &&
- props.selectedTypes.length > 0 &&
- props.selectedTypes.every(t => t === 'oauth')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'oauth')
+ )
+})
+
+const allOpenAIAPIKey = computed(() => {
+ return (
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'openai' &&
+ targetSelectedTypes.value.length > 0 &&
+ targetSelectedTypes.value.every(t => t === 'apikey')
)
})
// 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示)
const allAnthropicOAuthOrSetupToken = computed(() => {
return (
- props.selectedPlatforms.length === 1 &&
- props.selectedPlatforms[0] === 'anthropic' &&
- props.selectedTypes.every(t => t === 'oauth' || t === 'setup-token')
+ targetSelectedPlatforms.value.length === 1 &&
+ targetSelectedPlatforms.value[0] === 'anthropic' &&
+ targetSelectedTypes.value.every(t => t === 'oauth' || t === 'setup-token')
)
})
const filteredPresets = computed(() => {
- if (props.selectedPlatforms.length === 0) return []
+ if (targetSelectedPlatforms.value.length === 0) return []
const dedupedPresets = new Map[number]>()
- for (const platform of props.selectedPlatforms) {
+ for (const platform of targetSelectedPlatforms.value) {
for (const preset of getPresetMappingsByPlatform(platform)) {
const key = `${preset.from}=>${preset.to}`
if (!dedupedPresets.has(key)) {
@@ -1012,6 +1113,8 @@ const enableStatus = ref(false)
const enableGroups = ref(false)
const enableOpenAIPassthrough = ref(false)
const enableOpenAIWSMode = ref(false)
+const enableOpenAIAPIKeyWSMode = ref(false)
+const enableCodexCLIOnly = ref(false)
const enableRpmLimit = ref(false)
// State - field values
@@ -1035,6 +1138,8 @@ const status = ref<'active' | 'inactive'>('active')
const groupIds = ref([])
const openaiPassthroughEnabled = ref(false)
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
+const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
+const codexCLIOnlyEnabled = ref(false)
const rpmLimitEnabled = ref(false)
const bulkBaseRpm = ref(null)
const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered')
@@ -1076,6 +1181,9 @@ const openAIWSModeOptions = computed(() => [
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiOAuthResponsesWebSocketV2Mode.value)
)
+const openAIAPIKeyWSModeConcurrencyHintKey = computed(() =>
+ resolveOpenAIWSModeConcurrencyHintKey(openaiAPIKeyResponsesWebSocketV2Mode.value)
+)
// Model mapping helpers
const addModelMapping = () => {
@@ -1254,6 +1362,19 @@ const buildUpdatePayload = (): Record | null => {
)
}
+ if (enableOpenAIAPIKeyWSMode.value) {
+ const extra = ensureExtra()
+ extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value
+ extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(
+ openaiAPIKeyResponsesWebSocketV2Mode.value
+ )
+ }
+
+ if (enableCodexCLIOnly.value) {
+ const extra = ensureExtra()
+ extra.codex_cli_only = codexCLIOnlyEnabled.value
+ }
+
// RPM limit settings (写入 extra 字段)
if (enableRpmLimit.value) {
const extra = ensureExtra()
@@ -1291,8 +1412,8 @@ const mixedChannelConfirmed = ref(false)
const canPreCheck = () =>
enableGroups.value &&
groupIds.value.length > 0 &&
- props.selectedPlatforms.length === 1 &&
- (props.selectedPlatforms[0] === 'antigravity' || props.selectedPlatforms[0] === 'anthropic')
+ targetSelectedPlatforms.value.length === 1 &&
+ (targetSelectedPlatforms.value[0] === 'antigravity' || targetSelectedPlatforms.value[0] === 'anthropic')
const handleClose = () => {
showMixedChannelWarning.value = false
@@ -1309,7 +1430,7 @@ const preCheckMixedChannelRisk = async (built: Record): Promise
try {
const result = await adminAPI.accounts.checkMixedChannelRisk({
- platform: props.selectedPlatforms[0],
+ platform: targetSelectedPlatforms.value[0],
group_ids: groupIds.value
})
if (!result.has_risk) return true
@@ -1325,7 +1446,7 @@ const preCheckMixedChannelRisk = async (built: Record): Promise
}
const handleSubmit = async () => {
- if (props.accountIds.length === 0) {
+ if (targetMode.value === 'selected' && props.accountIds.length === 0) {
appStore.showError(t('admin.accounts.bulkEdit.noSelection'))
return
}
@@ -1344,6 +1465,8 @@ const handleSubmit = async () => {
enableStatus.value ||
enableGroups.value ||
enableOpenAIWSMode.value ||
+ enableOpenAIAPIKeyWSMode.value ||
+ enableCodexCLIOnly.value ||
enableRpmLimit.value ||
userMsgQueueMode.value !== null
@@ -1373,7 +1496,12 @@ const submitBulkUpdate = async (baseUpdates: Record) => {
submitting.value = true
try {
- const res = await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
+ const res = targetMode.value === 'filtered' && props.target?.filters
+ ? await adminAPI.accounts.bulkUpdate({
+ filters: props.target.filters,
+ ...updates
+ })
+ : await adminAPI.accounts.bulkUpdate(props.accountIds, updates)
const success = res.success || 0
const failed = res.failed || 0
@@ -1437,6 +1565,8 @@ watch(
enableGroups.value = false
enableOpenAIPassthrough.value = false
enableOpenAIWSMode.value = false
+ enableOpenAIAPIKeyWSMode.value = false
+ enableCodexCLIOnly.value = false
enableRpmLimit.value = false
// Reset all values
@@ -1456,6 +1586,8 @@ watch(
status.value = 'active'
groupIds.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
+ openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
+ codexCLIOnlyEnabled.value = false
rpmLimitEnabled.value = false
bulkBaseRpm.value = null
bulkRpmStrategy.value = 'tiered'
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 2130c9ab..d38c31c5 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -153,7 +153,7 @@
{{ t('admin.accounts.accountType') }}
-
+
+
+
+
+
+
+ Vertex
+ Service Account
+
+
+
+
+
+
+
{{ t('admin.accounts.vertexAnthropicHint') }}
@@ -302,6 +335,7 @@
{{ t('admin.accounts.types.responsesApi') }}
+
@@ -320,7 +354,7 @@
{{ t('admin.accounts.gemini.helpButton') }}
-
+
+
+
+
+
+
+
+
+ Vertex
+
+
+ Service Account
+
+
+
+
+
{{ t('admin.accounts.vertexGeminiHint') }}
+
+
{{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
@@ -610,7 +681,7 @@
-
+
{{ t('admin.accounts.gemini.tier.label') }}
+
+
+
+
Service Account JSON
+
+
+
+
+
+
+ {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonLoaded') : t('admin.accounts.vertexSaJsonDrop') }}
+
+
+ {{ vertexClientEmail ? t('admin.accounts.vertexSaJsonKeyHidden') : t('admin.accounts.vertexSaJsonDropHint') }}
+
+
+
+
+ {{ t('admin.accounts.vertexSaJsonSelectBtn') }}
+
+
+
+
Project ID: {{ vertexProjectId }}
+
Client Email: {{ vertexClientEmail }}
+
+
+
{{ t('admin.accounts.vertexSaJsonUploadHint') }}
+
+
+
+
+ Project ID
+
+
+
+
Location
+
+
+
+ {{ option.label }}
+
+
+
+
{{ t('admin.accounts.vertexLocationHint') }}
+
+
+
+
@@ -2449,6 +2610,45 @@
+
+
+
+
+
{{ t('admin.accounts.openai.compactMode') }}
+
+ {{ t('admin.accounts.openai.compactModeDesc') }}
+
+
+
+
+
+
+
+
{{ t('admin.accounts.openai.compactModelMapping') }}
+
{{ t('admin.accounts.openai.compactModelMappingDesc') }}
+
+
+ + {{ t('admin.accounts.addMapping') }}
+
+
+
+
@@ -2918,7 +3118,8 @@ import type {
AccountPlatform,
AccountType,
CheckMixedChannelResponse,
- CreateAccountRequest
+ CreateAccountRequest,
+ OpenAICompactMode
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
@@ -2931,6 +3132,7 @@ import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
+import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
import {
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
@@ -3045,7 +3247,7 @@ interface TempUnschedRuleForm {
// State
const step = ref(1)
const submitting = ref(false)
-const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock'>('oauth-based') // UI selection for account category
+const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'service_account'>('oauth-based') // UI selection for account category
const addMethod = ref
('oauth') // For oauth-based: 'oauth' or 'setup-token'
const apiKeyBaseUrl = ref('https://api.anthropic.com')
const apiKeyValue = ref('')
@@ -3059,6 +3261,7 @@ const editWeeklyResetDay = ref(null)
const editWeeklyResetHour = ref(null)
const editResetTimezone = ref(null)
const modelMappings = ref([])
+const openAICompactModelMappings = ref([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref([])
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
@@ -3071,6 +3274,7 @@ const customErrorCodeInput = ref(null)
const interceptWarmupRequests = ref(false)
const autoPauseOnExpired = ref(true)
const openaiPassthroughEnabled = ref(false)
+const openAICompactMode = ref('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -3109,13 +3313,25 @@ const bedrockSessionToken = ref('')
const bedrockRegion = ref('us-east-1')
const bedrockForceGlobal = ref(false)
const bedrockApiKeyValue = ref('')
+const vertexServiceAccountFileInput = ref(null)
+const vertexServiceAccountJson = ref('')
+const vertexProjectId = ref('')
+const vertexClientEmail = ref('')
+const vertexLocation = ref('global')
+const vertexServiceAccountDragActive = ref(false)
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref([])
const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping')
+const getOpenAICompactModelMappingKey = createStableObjectKeyResolver('create-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver('create-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver('create-temp-unsched-rule')
const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one')
const geminiAIStudioOAuthEnabled = ref(false)
+const openAICompactModeOptions = computed(() => [
+ { value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
+ { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
+ { value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
+])
function buildAntigravityExtra(): Record | undefined {
const extra: Record = {}
@@ -3124,6 +3340,9 @@ function buildAntigravityExtra(): Record | undefined {
return Object.keys(extra).length > 0 ? extra : undefined
}
+const buildOpenAICompactModelMapping = () =>
+ buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
+
const showMixedChannelWarning = ref(false)
const mixedChannelWarningDetails = ref<{ groupName: string; currentPlatform: string; otherPlatform: string } | null>(
null
@@ -3346,7 +3565,7 @@ watch(
// Sync form.type based on accountCategory, addMethod, and platform-specific type
watch(
- [accountCategory, addMethod, antigravityAccountType],
+ [accountCategory, addMethod, antigravityAccountType, () => form.platform],
([category, method, agType]) => {
// Antigravity upstream 类型(实际创建为 apikey)
if (form.platform === 'antigravity' && agType === 'upstream') {
@@ -3358,7 +3577,9 @@ watch(
form.type = 'bedrock' as AccountType
return
}
- if (category === 'oauth-based') {
+ if ((form.platform === 'gemini' || form.platform === 'anthropic') && category === 'service_account') {
+ form.type = 'service_account' as AccountType
+ } else if (category === 'oauth-based') {
form.type = method as AccountType // 'oauth' or 'setup-token'
} else {
form.type = 'apikey'
@@ -3396,6 +3617,12 @@ watch(
antigravityModelMappings.value = []
antigravityModelRestrictionMode.value = 'mapping'
}
+ if (newPlatform !== 'gemini' && newPlatform !== 'anthropic' && accountCategory.value === 'service_account') {
+ accountCategory.value = 'oauth-based'
+ }
+ if (newPlatform !== 'anthropic' && accountCategory.value === 'bedrock') {
+ accountCategory.value = 'oauth-based'
+ }
// Reset Bedrock fields when switching platforms
bedrockAccessKeyId.value = ''
bedrockSecretAccessKey.value = ''
@@ -3404,6 +3631,10 @@ watch(
bedrockForceGlobal.value = false
bedrockAuthMode.value = 'sigv4'
bedrockApiKeyValue.value = ''
+ vertexServiceAccountJson.value = ''
+ vertexProjectId.value = ''
+ vertexClientEmail.value = ''
+ vertexLocation.value = 'global'
// Reset Anthropic/Antigravity-specific settings when switching to other platforms
if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') {
interceptWarmupRequests.value = false
@@ -3489,6 +3720,14 @@ const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
}
+const addOpenAICompactModelMapping = () => {
+ openAICompactModelMappings.value.push({ from: '', to: '' })
+}
+
+const removeOpenAICompactModelMapping = (index: number) => {
+ openAICompactModelMappings.value.splice(index, 1)
+}
+
const removeModelMapping = (index: number) => {
modelMappings.value.splice(index, 1)
}
@@ -3781,6 +4020,7 @@ const resetForm = () => {
editWeeklyResetHour.value = null
editResetTimezone.value = null
modelMappings.value = []
+ openAICompactModelMappings.value = []
modelRestrictionMode.value = 'whitelist'
allowedModels.value = [...claudeModels] // Default fill related models
@@ -3797,6 +4037,7 @@ const resetForm = () => {
interceptWarmupRequests.value = false
autoPauseOnExpired.value = true
openaiPassthroughEnabled.value = false
+ openAICompactMode.value = 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -3825,6 +4066,10 @@ const resetForm = () => {
antigravityAccountType.value = 'oauth'
upstreamBaseUrl.value = ''
upstreamApiKey.value = ''
+ vertexServiceAccountJson.value = ''
+ vertexProjectId.value = ''
+ vertexClientEmail.value = ''
+ vertexLocation.value = 'global'
tempUnschedEnabled.value = false
tempUnschedRules.value = []
geminiOAuthType.value = 'code_assist'
@@ -3874,6 +4119,11 @@ const buildOpenAIExtra = (base?: Record): Record 0 ? extra : undefined
}
@@ -3943,6 +4193,52 @@ const normalizePoolModeRetryCount = (value: number) => {
return normalized
}
+const applyVertexServiceAccountJson = (value: string) => {
+ const raw = value.trim()
+ if (!raw) {
+ vertexProjectId.value = ''
+ vertexClientEmail.value = ''
+ return false
+ }
+ try {
+ const parsed = JSON.parse(raw) as Record
+ const projectId = typeof parsed.project_id === 'string' ? parsed.project_id.trim() : ''
+ const clientEmail = typeof parsed.client_email === 'string' ? parsed.client_email.trim() : ''
+ const privateKey = typeof parsed.private_key === 'string' ? parsed.private_key.trim() : ''
+ if (!projectId || !clientEmail || !privateKey) {
+ appStore.showError(t('admin.accounts.vertexSaJsonMissingFields'))
+ return false
+ }
+ vertexProjectId.value = projectId
+ vertexClientEmail.value = clientEmail
+ vertexServiceAccountJson.value = JSON.stringify(parsed)
+ return true
+ } catch {
+ appStore.showError(t('admin.accounts.vertexSaJsonInvalid'))
+ return false
+ }
+}
+
+const parseVertexServiceAccountJson = () => applyVertexServiceAccountJson(vertexServiceAccountJson.value)
+
+const handleVertexServiceAccountFile = async (event: Event) => {
+ const input = event.target as HTMLInputElement
+ const file = input.files?.[0]
+ if (!file) return
+ try {
+ applyVertexServiceAccountJson(await file.text())
+ } finally {
+ input.value = ''
+ }
+}
+
+const handleVertexServiceAccountDrop = async (event: DragEvent) => {
+ vertexServiceAccountDragActive.value = false
+ const file = event.dataTransfer?.files?.[0]
+ if (!file) return
+ applyVertexServiceAccountJson(await file.text())
+}
+
const handleSubmit = async () => {
// For OAuth-based type, handle OAuth flow (goes to step 2)
if (isOAuthFlow.value) {
@@ -4056,6 +4352,29 @@ const handleSubmit = async () => {
return
}
+ if ((form.platform === 'gemini' || form.platform === 'anthropic') && accountCategory.value === 'service_account') {
+ if (!form.name.trim()) {
+ appStore.showError(t('admin.accounts.pleaseEnterAccountName'))
+ return
+ }
+ if (!parseVertexServiceAccountJson()) {
+ return
+ }
+ if (!vertexLocation.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexLocationRequired'))
+ return
+ }
+ const credentials: Record = {
+ service_account_json: vertexServiceAccountJson.value.trim(),
+ project_id: vertexProjectId.value.trim(),
+ client_email: vertexClientEmail.value.trim(),
+ location: vertexLocation.value.trim(),
+ tier_id: 'vertex'
+ }
+ await createAccountAndFinish(form.platform, 'service_account' as AccountType, credentials)
+ return
+ }
+
// For apikey type, create directly
if (!apiKeyValue.value.trim()) {
appStore.showError(t('admin.accounts.pleaseEnterApiKey'))
@@ -4086,6 +4405,12 @@ const handleSubmit = async () => {
credentials.model_mapping = modelMapping
}
}
+ if (form.platform === 'openai') {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ }
+ }
// Add pool mode if enabled
if (poolModeEnabled.value) {
@@ -4198,6 +4523,14 @@ const createAccountAndFinish = async (
finalExtra = quotaExtra
}
}
+ if (platform === 'openai') {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ } else {
+ delete credentials.compact_model_mapping
+ }
+ }
await doCreateAccount({
name: form.name,
notes: form.notes,
@@ -4252,6 +4585,12 @@ const handleOpenAIExchange = async (authCode: string) => {
credentials.model_mapping = modelMapping
}
}
+ if (shouldCreateOpenAI) {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ }
+ }
// 应用临时不可调度配置
if (!applyTempUnschedConfig(credentials)) {
@@ -4344,6 +4683,12 @@ const handleOpenAIBatchRT = async (refreshTokenInput: string, clientId?: string)
credentials.model_mapping = modelMapping
}
}
+ if (shouldCreateOpenAI) {
+ const compactModelMapping = buildOpenAICompactModelMapping()
+ if (compactModelMapping) {
+ credentials.compact_model_mapping = compactModelMapping
+ }
+ }
// Generate account name; fallback to email if name is empty (ent schema requires NotEmpty)
const baseName = form.name || tokenInfo.email || 'OpenAI OAuth Account'
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 1da32e2c..56874474 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -52,6 +52,10 @@
v-model="editApiKey"
type="password"
class="input font-mono"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
:placeholder="
account.platform === 'openai'
? 'sk-proj-...'
@@ -563,6 +567,221 @@
+
+
+
+
+
Project ID
+
+
{{ t('admin.accounts.vertexSaJsonEditHint') }}
+
+
+
Location
+
+
+
+ {{ option.label }}
+
+
+
+
{{ t('admin.accounts.vertexLocationHint') }}
+
+
+
+
+
+
{{ t('admin.accounts.modelRestriction') }}
+
+
+
+
+
+
+
+ {{ t('admin.accounts.modelWhitelist') }}
+
+
+
+
+
+ {{ t('admin.accounts.modelMapping') }}
+
+
+
+
+
+
+
+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
+ {{
+ t('admin.accounts.supportsAllModels')
+ }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.mapRequestModels') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.addMapping') }}
+
+
+
+
+
+ + {{ preset.label }}
+
+
+
+
+
+
@@ -1302,6 +1521,64 @@
+
+
+
+
{{ t('admin.accounts.openai.compactMode') }}
+
+ {{ t('admin.accounts.openai.compactModeDesc') }}
+
+
+
+
+
+
+
+ {{ t(openAICompactStatusKey) }}
+
+ {{ t('admin.accounts.openai.compactLastChecked') }}:
+ {{ formatDateTime(new Date(String(account.extra.openai_compact_checked_at))) }}
+
+
+
+
{{ t('admin.accounts.openai.compactModelMapping') }}
+
{{ t('admin.accounts.openai.compactModelMappingDesc') }}
+
+
+ + {{ t('admin.accounts.addMapping') }}
+
+
+
+
@@ -1845,7 +2122,7 @@ import { useAppStore } from '@/stores/app'
import { useAuthStore } from '@/stores/auth'
import { adminAPI } from '@/api/admin'
import { useQuotaNotifyState } from '@/composables/useQuotaNotifyState'
-import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse } from '@/types'
+import type { Account, Proxy, AdminGroup, CheckMixedChannelResponse, OpenAICompactMode } from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
import Select from '@/components/common/Select.vue'
@@ -1855,8 +2132,9 @@ import GroupSelector from '@/components/common/GroupSelector.vue'
import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue'
import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue'
import { applyInterceptWarmup } from '@/components/account/credentialsBuilder'
-import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
+import { formatDateTime, formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format'
import { createStableObjectKeyResolver } from '@/utils/stableObjectKey'
+import { VERTEX_LOCATION_OPTIONS } from '@/constants/account'
import {
OPENAI_WS_MODE_CTX_POOL,
OPENAI_WS_MODE_OFF,
@@ -1925,11 +2203,15 @@ const editBedrockSessionToken = ref('')
const editBedrockRegion = ref('')
const editBedrockForceGlobal = ref(false)
const editBedrockApiKeyValue = ref('')
+const editVertexProjectId = ref('')
+const editVertexClientEmail = ref('')
+const editVertexLocation = ref('us-central1')
const isBedrockAPIKeyMode = computed(() =>
props.account?.type === 'bedrock' &&
(props.account?.credentials as Record
)?.auth_mode === 'apikey'
)
const modelMappings = ref([])
+const openAICompactModelMappings = ref([])
const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist')
const allowedModels = ref([])
const DEFAULT_POOL_MODE_RETRY_COUNT = 3
@@ -1949,6 +2231,7 @@ const antigravityModelMappings = ref([])
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref([])
const getModelMappingKey = createStableObjectKeyResolver('edit-model-mapping')
+const getOpenAICompactModelMappingKey = createStableObjectKeyResolver('edit-openai-compact-model-mapping')
const getAntigravityModelMappingKey = createStableObjectKeyResolver('edit-antigravity-model-mapping')
const getTempUnschedRuleKey = createStableObjectKeyResolver('edit-temp-unsched-rule')
@@ -1988,6 +2271,7 @@ const customBaseUrl = ref('')
// OpenAI 自动透传开关(OAuth/API Key)
const openaiPassthroughEnabled = ref(false)
+const openAICompactMode = ref('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -2041,9 +2325,27 @@ const openaiResponsesWebSocketV2Mode = computed({
const openAIWSModeConcurrencyHintKey = computed(() =>
resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value)
)
+const openAICompactModeOptions = computed(() => [
+ { value: 'auto', label: t('admin.accounts.openai.compactModeAuto') },
+ { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
+ { value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
+])
const isOpenAIModelRestrictionDisabled = computed(() =>
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
)
+const openAICompactStatusKey = computed(() => {
+ const extra = props.account?.extra as Record | undefined
+ if (!props.account || props.account.platform !== 'openai') return ''
+ const mode = typeof extra?.openai_compact_mode === 'string' ? extra.openai_compact_mode : 'auto'
+ if (mode === 'force_on') return 'admin.accounts.openai.compactSupported'
+ if (mode === 'force_off') return 'admin.accounts.openai.compactUnsupported'
+ if (typeof extra?.openai_compact_supported === 'boolean') {
+ return extra.openai_compact_supported
+ ? 'admin.accounts.openai.compactSupported'
+ : 'admin.accounts.openai.compactUnsupported'
+ }
+ return 'admin.accounts.openai.compactUnknown'
+})
// Computed: current preset mappings based on platform
const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic'))
@@ -2163,6 +2465,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
const credentials = newAccount.credentials as Record | undefined
interceptWarmupRequests.value = credentials?.intercept_warmup_requests === true
autoPauseOnExpired.value = newAccount.auto_pause_on_expired === true
+ editVertexProjectId.value = ''
+ editVertexClientEmail.value = ''
+ editVertexLocation.value = 'us-central1'
// Load mixed scheduling setting (only for antigravity accounts)
mixedScheduling.value = false
@@ -2173,6 +2478,8 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
openaiPassthroughEnabled.value = false
+ openAICompactMode.value = 'auto'
+ openAICompactModelMappings.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -2180,6 +2487,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
webSearchEmulationMode.value = 'default'
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
+ openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, {
modeKey: 'openai_oauth_responses_websockets_v2_mode',
enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
@@ -2195,6 +2503,11 @@ const syncFormFromAccount = (newAccount: Account | null) => {
if (newAccount.type === 'oauth') {
codexCLIOnlyEnabled.value = extra?.codex_cli_only === true
}
+ const credentials = newAccount.credentials as Record | undefined
+ const compactMappings = credentials?.compact_model_mapping as Record | undefined
+ if (compactMappings && typeof compactMappings === 'object') {
+ openAICompactModelMappings.value = Object.entries(compactMappings).map(([from, to]) => ({ from, to }))
+ }
}
if (newAccount.platform === 'anthropic' && newAccount.type === 'apikey') {
anthropicPassthroughEnabled.value = extra?.anthropic_passthrough === true
@@ -2376,6 +2689,31 @@ const syncFormFromAccount = (newAccount: Account | null) => {
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record
editBaseUrl.value = (credentials.base_url as string) || ''
+ } else if ((newAccount.platform === 'gemini' || newAccount.platform === 'anthropic') && newAccount.type === 'service_account' && newAccount.credentials) {
+ const credentials = newAccount.credentials as Record
+ editVertexProjectId.value = (credentials.project_id as string) || ''
+ editVertexClientEmail.value = (credentials.client_email as string) || ''
+ editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1'
+
+ // Load model mappings for service_account
+ const existingMappings = credentials.model_mapping as Record | undefined
+ if (existingMappings && typeof existingMappings === 'object') {
+ const entries = Object.entries(existingMappings)
+ const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
+ if (isWhitelistMode) {
+ modelRestrictionMode.value = 'whitelist'
+ allowedModels.value = entries.map(([from]) => from)
+ modelMappings.value = []
+ } else {
+ modelRestrictionMode.value = 'mapping'
+ modelMappings.value = entries.map(([from, to]) => ({ from, to }))
+ allowedModels.value = []
+ }
+ } else {
+ modelRestrictionMode.value = 'whitelist'
+ modelMappings.value = []
+ allowedModels.value = []
+ }
} else {
const platformDefaultUrl =
newAccount.platform === 'openai'
@@ -2419,6 +2757,15 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editApiKey.value = ''
}
+async function loadTLSProfiles() {
+ try {
+ const profiles = await adminAPI.tlsFingerprintProfiles.list()
+ tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name }))
+ } catch {
+ tlsFingerprintProfiles.value = []
+ }
+}
+
watch(
[() => props.show, () => props.account],
([show, newAccount], [wasShow, previousAccount]) => {
@@ -2433,15 +2780,6 @@ watch(
{ immediate: true }
)
-const loadTLSProfiles = async () => {
- try {
- const profiles = await adminAPI.tlsFingerprintProfiles.list()
- tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name }))
- } catch {
- tlsFingerprintProfiles.value = []
- }
-}
-
// Model mapping helpers
const addModelMapping = () => {
modelMappings.value.push({ from: '', to: '' })
@@ -2464,6 +2802,14 @@ const addAntigravityModelMapping = () => {
antigravityModelMappings.value.push({ from: '', to: '' })
}
+const addOpenAICompactModelMapping = () => {
+ openAICompactModelMappings.value.push({ from: '', to: '' })
+}
+
+const removeOpenAICompactModelMapping = (index: number) => {
+ openAICompactModelMappings.value.splice(index, 1)
+}
+
const removeAntigravityModelMapping = (index: number) => {
antigravityModelMappings.value.splice(index, 1)
}
@@ -2907,6 +3253,14 @@ const handleSubmit = async () => {
} else if (currentCredentials.model_mapping) {
newCredentials.model_mapping = currentCredentials.model_mapping
}
+ if (props.account.platform === 'openai') {
+ const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
+ if (compactModelMapping) {
+ newCredentials.compact_model_mapping = compactModelMapping
+ } else {
+ delete newCredentials.compact_model_mapping
+ }
+ }
// Add pool mode if enabled
if (poolModeEnabled.value) {
@@ -2950,6 +3304,46 @@ const handleSubmit = async () => {
return
}
+ updatePayload.credentials = newCredentials
+ } else if ((props.account.platform === 'gemini' || props.account.platform === 'anthropic') && props.account.type === 'service_account') {
+ const currentCredentials = (props.account.credentials as Record) || {}
+ const newCredentials: Record = { ...currentCredentials }
+
+ if (!editVertexProjectId.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexSaJsonMissingProjectId'))
+ return
+ }
+ if (!editVertexClientEmail.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexSaJsonMissingClientEmail'))
+ return
+ }
+ if (!editVertexLocation.value.trim()) {
+ appStore.showError(t('admin.accounts.vertexLocationRequired'))
+ return
+ }
+
+ if (!currentCredentials.service_account_json && !currentCredentials.service_account) {
+ appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
+ return
+ }
+ newCredentials.project_id = editVertexProjectId.value.trim()
+ newCredentials.client_email = editVertexClientEmail.value.trim()
+ newCredentials.location = editVertexLocation.value.trim()
+ newCredentials.tier_id = 'vertex'
+
+ // Add model mapping if configured
+ const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+ if (modelMapping) {
+ newCredentials.model_mapping = modelMapping
+ } else {
+ delete newCredentials.model_mapping
+ }
+
+ applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit')
+ if (!applyTempUnschedConfig(newCredentials)) {
+ return
+ }
+
updatePayload.credentials = newCredentials
} else if (props.account.type === 'bedrock') {
const currentCredentials = (props.account.credentials as Record) || {}
@@ -3032,6 +3426,12 @@ const handleSubmit = async () => {
// 透传模式保留现有映射
newCredentials.model_mapping = currentCredentials.model_mapping
}
+ const compactModelMapping = buildModelMappingObject('mapping', [], openAICompactModelMappings.value)
+ if (compactModelMapping) {
+ newCredentials.compact_model_mapping = compactModelMapping
+ } else {
+ delete newCredentials.compact_model_mapping
+ }
updatePayload.credentials = newCredentials
}
@@ -3204,6 +3604,11 @@ const handleSubmit = async () => {
delete newExtra.openai_passthrough
delete newExtra.openai_oauth_passthrough
}
+ if (openAICompactMode.value === 'auto') {
+ delete newExtra.openai_compact_mode
+ } else {
+ newExtra.openai_compact_mode = openAICompactMode.value
+ }
if (props.account.type === 'oauth') {
if (codexCLIOnlyEnabled.value) {
diff --git a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts
index 7cdf7999..f758e6b0 100644
--- a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts
+++ b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts
@@ -122,7 +122,7 @@ describe('AccountStatusIndicator', () => {
}
})
- expect(wrapper.text()).toContain('account.creditsExhausted')
+ expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
it('模型限流 + overages 启用 + AICredits key 生效 → 普通限流样式(积分耗尽,无 ⚡)', () => {
@@ -157,6 +157,6 @@ describe('AccountStatusIndicator', () => {
expect(wrapper.text()).toContain('CSon45')
expect(wrapper.text()).not.toContain('⚡')
// AICredits 积分耗尽状态应显示
- expect(wrapper.text()).toContain('account.creditsExhausted')
+ expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted')
})
})
diff --git a/frontend/src/components/account/__tests__/AccountTestModal.spec.ts b/frontend/src/components/account/__tests__/AccountTestModal.spec.ts
new file mode 100644
index 00000000..c82a3840
--- /dev/null
+++ b/frontend/src/components/account/__tests__/AccountTestModal.spec.ts
@@ -0,0 +1,150 @@
+import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+import { defineComponent } from 'vue'
+import AccountTestModal from '../AccountTestModal.vue'
+
+const { getAvailableModelsMock } = vi.hoisted(() => ({
+ getAvailableModelsMock: vi.fn()
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ accounts: {
+ getAvailableModels: getAvailableModelsMock
+ }
+ }
+}))
+
+vi.mock('@/composables/useClipboard', () => ({
+ useClipboard: () => ({
+ copyToClipboard: vi.fn()
+ })
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+const BaseDialogStub = defineComponent({
+ name: 'BaseDialog',
+ props: { show: { type: Boolean, default: false } },
+ template: '
'
+})
+
+const SelectStub = defineComponent({
+ name: 'SelectStub',
+ props: {
+ modelValue: { type: [String, Number, Boolean, null], default: '' },
+ options: { type: Array, default: () => [] },
+ valueKey: { type: String, default: 'value' },
+ labelKey: { type: String, default: 'label' }
+ },
+ emits: ['update:modelValue'],
+ template: `
+
+
+ {{ option[labelKey] }}
+
+
+ `
+})
+
+const TextAreaStub = defineComponent({
+ name: 'TextArea',
+ props: {
+ modelValue: { type: String, default: '' }
+ },
+ emits: ['update:modelValue'],
+ template: `
+
+ `
+})
+
+function buildAccount() {
+ return {
+ id: 1,
+ name: 'OpenAI OAuth',
+ platform: 'openai',
+ type: 'oauth',
+ status: 'active',
+ credentials: {},
+ extra: {},
+ concurrency: 1,
+ priority: 1,
+ proxy_id: null,
+ auto_pause_on_expired: false
+ } as any
+}
+
+describe('AccountTestModal', () => {
+ const originalFetch = global.fetch
+
+ beforeEach(() => {
+ getAvailableModelsMock.mockReset()
+ getAvailableModelsMock.mockResolvedValue([
+ { id: 'gpt-5.4', display_name: 'GPT-5.4' }
+ ])
+ global.fetch = vi.fn().mockResolvedValue({
+ ok: true,
+ body: {
+ getReader: () => ({
+ read: vi.fn().mockResolvedValue({ done: true, value: undefined })
+ })
+ }
+ } as any)
+ localStorage.setItem('auth_token', 'test-token')
+ })
+
+ afterEach(() => {
+ global.fetch = originalFetch
+ localStorage.clear()
+ })
+
+ it('posts compact mode for OpenAI compact probe', async () => {
+ const wrapper = mount(AccountTestModal, {
+ props: {
+ show: true,
+ account: buildAccount()
+ },
+ global: {
+ stubs: {
+ BaseDialog: BaseDialogStub,
+ Select: SelectStub,
+ TextArea: TextAreaStub,
+ Icon: true
+ }
+ }
+ })
+
+ await flushPromises()
+ ;(wrapper.vm as any).selectedModelId = 'gpt-5.4'
+ ;(wrapper.vm as any).testMode = 'compact'
+ await (wrapper.vm as any).startTest()
+ await flushPromises()
+
+ expect(global.fetch).toHaveBeenCalledTimes(1)
+ const [, options] = (global.fetch as any).mock.calls[0]
+ expect(JSON.parse(options.body)).toMatchObject({
+ model_id: 'gpt-5.4',
+ mode: 'compact'
+ })
+ })
+})
diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts
index 9158da64..fa4104f6 100644
--- a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts
+++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts
@@ -57,6 +57,19 @@ function makeAccount(overrides: Partial): Account {
describe('AccountUsageCell', () => {
beforeEach(() => {
getUsage.mockReset()
+ Object.defineProperty(window, 'matchMedia', {
+ writable: true,
+ value: vi.fn().mockImplementation(() => ({
+ matches: true,
+ media: '(min-width: 768px)',
+ onchange: null,
+ addListener: vi.fn(),
+ removeListener: vi.fn(),
+ addEventListener: vi.fn(),
+ removeEventListener: vi.fn(),
+ dispatchEvent: vi.fn(),
+ }))
+ })
})
it('Antigravity 图片用量会聚合新旧 image 模型', async () => {
@@ -603,4 +616,43 @@ describe('AccountUsageCell', () => {
expect(wrapper.text().trim()).toBe('-')
})
+
+ it('Vertex 账号会在 Gemini 用量窗口里展示 today stats 徽章', async () => {
+ const wrapper = mount(AccountUsageCell, {
+ props: {
+ account: makeAccount({
+ id: 4001,
+ platform: 'gemini',
+ type: 'service_account',
+ credentials: {
+ tier_id: 'vertex',
+ project_id: 'vertex-proj',
+ client_email: 'svc@vertex-proj.iam.gserviceaccount.com',
+ location: 'global'
+ },
+ extra: {}
+ }),
+ todayStats: {
+ requests: 0,
+ tokens: 0,
+ cost: 0,
+ standard_cost: 0,
+ user_cost: 0
+ }
+ },
+ global: {
+ stubs: {
+ UsageProgressBar: true,
+ AccountQuotaInfo: true
+ }
+ }
+ })
+
+ await flushPromises()
+
+ expect(wrapper.text()).toContain('0 req')
+ expect(wrapper.text()).toContain('0')
+ expect(wrapper.text()).toContain('A $0.00')
+ expect(wrapper.text()).toContain('U $0.00')
+ })
})
diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts
index 7390e723..50d170da 100644
--- a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts
+++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts
@@ -178,6 +178,45 @@ describe('BulkEditAccountModal', () => {
expect(wrapper.find('#bulk-edit-openai-ws-mode-enabled').exists()).toBe(false)
})
+ it('OpenAI OAuth 批量编辑应提交 codex_cli_only 字段', async () => {
+ const wrapper = mountModal({
+ selectedPlatforms: ['openai'],
+ selectedTypes: ['oauth']
+ })
+
+ await wrapper.get('#bulk-edit-openai-codex-cli-only-enabled').setValue(true)
+ await wrapper.get('#bulk-edit-openai-codex-cli-only-toggle').trigger('click')
+ await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
+ extra: {
+ codex_cli_only: true
+ }
+ })
+ })
+
+ it('OpenAI API Key 批量编辑应提交 API Key 专属 WS mode 字段', async () => {
+ const wrapper = mountModal({
+ selectedPlatforms: ['openai'],
+ selectedTypes: ['apikey']
+ })
+
+ await wrapper.get('#bulk-edit-openai-apikey-ws-mode-enabled').setValue(true)
+ await wrapper.get('[data-testid="bulk-edit-openai-apikey-ws-mode-select"]').setValue('ctx_pool')
+ await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith([1, 2], {
+ extra: {
+ openai_apikey_responses_websockets_v2_mode: 'ctx_pool',
+ openai_apikey_responses_websockets_v2_enabled: true
+ }
+ })
+ })
+
it('OpenAI 账号批量编辑可关闭自动透传', async () => {
const wrapper = mountModal({
selectedPlatforms: ['openai'],
@@ -217,4 +256,41 @@ describe('BulkEditAccountModal', () => {
})
expect(wrapper.text()).toContain('admin.accounts.openai.modelRestrictionDisabledByPassthrough')
})
+
+ it('filtered-results 模式下应提交 filters 而不是 account_ids', async () => {
+ const wrapper = mountModal({
+ accountIds: [],
+ target: {
+ mode: 'filtered',
+ filters: {
+ platform: 'openai',
+ type: 'oauth',
+ status: 'active',
+ group: '12',
+ search: 'bulk-target',
+ privacy_mode: 'training_set_cf_blocked'
+ },
+ previewCount: 5,
+ selectedPlatforms: ['openai'],
+ selectedTypes: ['oauth']
+ }
+ })
+
+ await wrapper.get('#bulk-edit-status-enabled').setValue(true)
+ await wrapper.get('#bulk-edit-account-form').trigger('submit.prevent')
+ await flushPromises()
+
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledTimes(1)
+ expect(adminAPI.accounts.bulkUpdate).toHaveBeenCalledWith({
+ filters: {
+ platform: 'openai',
+ type: 'oauth',
+ status: 'active',
+ group: '12',
+ search: 'bulk-target',
+ privacy_mode: 'training_set_cf_blocked'
+ },
+ status: 'active'
+ })
+ })
})
diff --git a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
index e3260168..c4e2a9bc 100644
--- a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
+++ b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
@@ -26,6 +26,13 @@ vi.mock('@/api/admin', () => ({
accounts: {
update: updateAccountMock,
checkMixedChannelRisk: checkMixedChannelRiskMock
+ },
+ settings: {
+ getWebSearchEmulationConfig: vi.fn().mockResolvedValue({ enabled: false, providers: [] }),
+ getSettings: vi.fn().mockResolvedValue({})
+ },
+ tlsFingerprintProfiles: {
+ list: vi.fn().mockResolvedValue([])
}
}
}))
@@ -82,6 +89,32 @@ const ModelWhitelistSelectorStub = defineComponent({
`
})
+const SelectStub = defineComponent({
+ name: 'SelectStub',
+ props: {
+ modelValue: {
+ type: [String, Number, Boolean, null],
+ default: ''
+ },
+ options: {
+ type: Array,
+ default: () => []
+ }
+ },
+ emits: ['update:modelValue'],
+ template: `
+
+
+ {{ option.label }}
+
+
+ `
+})
+
function buildAccount() {
return {
id: 1,
@@ -119,7 +152,7 @@ function mountModal(account = buildAccount()) {
global: {
stubs: {
BaseDialog: BaseDialogStub,
- Select: true,
+ Select: SelectStub,
Icon: true,
ProxySelector: true,
GroupSelector: true,
@@ -156,4 +189,31 @@ describe('EditAccountModal', () => {
'gpt-5.2': 'gpt-5.2'
})
})
+
+ it('submits OpenAI compact mode and compact-only model mapping', async () => {
+ const account = buildAccount()
+ account.extra = {
+ openai_compact_mode: 'force_on'
+ }
+ account.credentials = {
+ ...account.credentials,
+ compact_model_mapping: {
+ 'gpt-5.4': 'gpt-5.4-openai-compact'
+ }
+ }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_compact_mode).toBe('force_on')
+ expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.compact_model_mapping).toEqual({
+ 'gpt-5.4': 'gpt-5.4-openai-compact'
+ })
+ })
})
diff --git a/frontend/src/components/admin/account/AccountBulkActionsBar.vue b/frontend/src/components/admin/account/AccountBulkActionsBar.vue
index 3b987bd0..a632bdd4 100644
--- a/frontend/src/components/admin/account/AccountBulkActionsBar.vue
+++ b/frontend/src/components/admin/account/AccountBulkActionsBar.vue
@@ -1,9 +1,13 @@
-
+
-
+
{{ t('admin.accounts.bulkActions.selected', { count: selectedIds.length }) }}
+
+ {{ t('admin.accounts.bulkEdit.title') }}
+
+
{{ t('admin.accounts.bulkActions.clear') }}
+
- {{ t('admin.accounts.bulkActions.delete') }}
- {{ t('admin.accounts.bulkActions.resetStatus') }}
- {{ t('admin.accounts.bulkActions.refreshToken') }}
- {{ t('admin.accounts.bulkActions.enableScheduling') }}
- {{ t('admin.accounts.bulkActions.disableScheduling') }}
- {{ t('admin.accounts.bulkActions.edit') }}
+
+ {{ t('admin.accounts.bulkActions.delete') }}
+ {{ t('admin.accounts.bulkActions.resetStatus') }}
+ {{ t('admin.accounts.bulkActions.refreshToken') }}
+ {{ t('admin.accounts.bulkActions.enableScheduling') }}
+ {{ t('admin.accounts.bulkActions.disableScheduling') }}
+ {{ t('admin.accounts.bulkActions.edit') }}
+
+
+ {{ t('admin.accounts.bulkEdit.submit') }}
+
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue
index 67409a7c..2e3db61b 100644
--- a/frontend/src/components/admin/account/AccountTestModal.vue
+++ b/frontend/src/components/admin/account/AccountTestModal.vue
@@ -55,12 +55,12 @@
/>
-
+
@@ -122,25 +122,49 @@
- {{ t('admin.accounts.geminiImagePreview') }}
+ {{ t('admin.accounts.imagePreview') }}
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -152,8 +176,8 @@
{{
- supportsGeminiImageTest
- ? t('admin.accounts.geminiImageTestMode')
+ supportsImageTest
+ ? t('admin.accounts.imageTestMode')
: t('admin.accounts.testPrompt')
}}
@@ -250,6 +274,7 @@ const testPrompt = ref('')
const loadingModels = ref(false)
let abortController: AbortController | null = null
const generatedImages = ref
([])
+const previewImageUrl = ref('')
const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
const modelID = selectedModelId.value.toLowerCase()
@@ -258,6 +283,14 @@ const supportsGeminiImageTest = computed(() => {
return props.account?.platform === 'gemini' || (props.account?.platform === 'antigravity' && props.account?.type === 'apikey')
})
+const supportsOpenAIImageTest = computed(() => {
+ const modelID = selectedModelId.value.toLowerCase()
+ if (!modelID.startsWith('gpt-image-')) return false
+ return props.account?.platform === 'openai'
+})
+
+const supportsImageTest = computed(() => supportsGeminiImageTest.value || supportsOpenAIImageTest.value)
+
const sortTestModels = (models: ClaudeModel[]) => {
const priorityMap = new Map(prioritizedGeminiModels.map((id, index) => [id, index]))
@@ -284,8 +317,8 @@ watch(
)
watch(selectedModelId, () => {
- if (supportsGeminiImageTest.value && !testPrompt.value.trim()) {
- testPrompt.value = t('admin.accounts.geminiImagePromptDefault')
+ if (supportsImageTest.value && !testPrompt.value.trim()) {
+ testPrompt.value = t('admin.accounts.imagePromptDefault')
}
})
@@ -325,6 +358,7 @@ const resetState = () => {
streamingContent.value = ''
errorMessage.value = ''
generatedImages.value = []
+ previewImageUrl.value = ''
}
const handleClose = () => {
@@ -377,7 +411,7 @@ const startTest = async () => {
},
body: JSON.stringify({
model_id: selectedModelId.value,
- prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : ''
+ prompt: supportsImageTest.value ? testPrompt.value.trim() : ''
}),
signal: abortController.signal
})
@@ -444,8 +478,8 @@ const handleEvent = (event: {
addLine(t('admin.accounts.usingModel', { model: event.model }), 'text-cyan-400')
}
addLine(
- supportsGeminiImageTest.value
- ? t('admin.accounts.sendingGeminiImageRequest')
+ supportsImageTest.value
+ ? t('admin.accounts.sendingImageRequest')
: t('admin.accounts.sendingTestMessage'),
'text-gray-400'
)
@@ -466,7 +500,7 @@ const handleEvent = (event: {
url: event.image_url,
mimeType: event.mime_type
})
- addLine(t('admin.accounts.geminiImageReceived', { count: generatedImages.value.length }), 'text-purple-300')
+ addLine(t('admin.accounts.imageReceived', { count: generatedImages.value.length }), 'text-purple-300')
}
break
@@ -500,3 +534,14 @@ const copyOutput = () => {
copyToClipboard(text, t('admin.accounts.outputCopied'))
}
+
+
diff --git a/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts b/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts
index 429a905c..eb1a7b9d 100644
--- a/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts
+++ b/frontend/src/components/admin/account/__tests__/AccountTestModal.spec.ts
@@ -24,13 +24,13 @@ vi.mock('@/composables/useClipboard', () => ({
vi.mock('vue-i18n', async () => {
const actual = await vi.importActual('vue-i18n')
const messages: Record = {
- 'admin.accounts.geminiImagePromptDefault': 'Generate a cute orange cat astronaut sticker on a clean pastel background.'
+ 'admin.accounts.imagePromptDefault': 'Generate a cute orange cat astronaut sticker on a clean pastel background.'
}
return {
...actual,
useI18n: () => ({
t: (key: string, params?: Record) => {
- if (key === 'admin.accounts.geminiImageReceived' && params?.count) {
+ if (key === 'admin.accounts.imageReceived' && params?.count) {
return `received-${params.count}`
}
return messages[key] || key
@@ -140,7 +140,7 @@ describe('AccountTestModal', () => {
prompt: 'draw a tiny orange cat astronaut'
})
- const preview = wrapper.find('img[alt="gemini-test-image-1"]')
+ const preview = wrapper.find('img[alt="test-image-1"]')
expect(preview.exists()).toBe(true)
expect(preview.attributes('src')).toBe('data:image/png;base64,QUJD')
})
diff --git a/frontend/src/components/admin/channel/ModelTagInput.vue b/frontend/src/components/admin/channel/ModelTagInput.vue
index a1ce4022..b91aa119 100644
--- a/frontend/src/components/admin/channel/ModelTagInput.vue
+++ b/frontend/src/components/admin/channel/ModelTagInput.vue
@@ -27,6 +27,7 @@
@keydown.tab.prevent="addModel"
@keydown.delete="handleBackspace"
@paste="handlePaste"
+ @blur="addModel"
/>
diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts
index b3966289..955b6487 100644
--- a/frontend/src/components/admin/channel/types.ts
+++ b/frontend/src/components/admin/channel/types.ts
@@ -187,3 +187,14 @@ export function getPlatformTagClass(platform: string): string {
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
}
}
+
+/** 平台对应的模型文字色(仅 text-*,用于 input/text 场景)— 与 getPlatformTagClass 同色系 */
+export function getPlatformTextClass(platform: string): string {
+ switch (platform) {
+ case 'anthropic': return 'text-orange-700 dark:text-orange-400'
+ case 'openai': return 'text-emerald-700 dark:text-emerald-400'
+ case 'gemini': return 'text-blue-700 dark:text-blue-400'
+ case 'antigravity': return 'text-purple-700 dark:text-purple-400'
+ default: return ''
+ }
+}
diff --git a/frontend/src/components/admin/group/GroupRPMOverridesModal.vue b/frontend/src/components/admin/group/GroupRPMOverridesModal.vue
new file mode 100644
index 00000000..a4b4e536
--- /dev/null
+++ b/frontend/src/components/admin/group/GroupRPMOverridesModal.vue
@@ -0,0 +1,434 @@
+
+
+
+
+
+
+
+ {{ t('admin.groups.platforms.' + group.platform) }}
+
+
|
+
{{ group.name }}
+
|
+
+ {{ t('admin.groups.groupRpmDefault') }}: {{ group.rpm_limit || 0 }}
+
+
+
+
+
+
+ {{ t('admin.groups.addUserRpm') }}
+
+
+
+
+
+
+ #{{ user.id }}
+ {{ user.username || user.email }}
+ {{ user.email }}
+
+
+
+
+
+
+
+ {{ t('common.add') }}
+
+
+
+
+
+
+ {{ t('admin.groups.clearAll') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.groups.rpmOverrides') }} ({{ localEntries.length }})
+
+
+
+ {{ t('admin.groups.noRpmOverrides') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.groups.unsavedChanges') }}
+
+ {{ t('admin.groups.revertChanges') }}
+
+
+
+
+ {{ t('common.close') }}
+
+
+
+ {{ t('common.save') }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
index bf79bea2..d68f3aa5 100644
--- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
+++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue
@@ -166,9 +166,10 @@
@@ -294,19 +295,17 @@ const showFinalRate = computed(() => {
})
// 计算最终倍率预览
-const computeFinalRate = (rate: number) => {
- if (!batchFactor.value) return rate
- return parseFloat((rate * batchFactor.value).toFixed(6))
+const computeFinalRate = (rate: number | null | undefined) => {
+ const base = rate ?? props.group?.rate_multiplier ?? 1
+ if (!batchFactor.value) return base
+ return parseFloat((base * batchFactor.value).toFixed(6))
}
// 检测是否有未保存的修改
const isDirty = computed(() => {
if (localEntries.value.length !== serverEntries.value.length) return true
- const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier]))
- return localEntries.value.some(e => {
- const serverRate = serverMap.get(e.user_id)
- return serverRate === undefined || serverRate !== e.rate_multiplier
- })
+ const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier ?? null]))
+ return localEntries.value.some(e => serverMap.get(e.user_id) !== (e.rate_multiplier ?? null))
})
const paginatedLocalEntries = computed(() => {
@@ -322,7 +321,9 @@ const loadEntries = async () => {
if (!props.group) return
loading.value = true
try {
- serverEntries.value = await adminAPI.groups.getGroupRateMultipliers(props.group.id)
+ const raw = await adminAPI.groups.getGroupRateMultipliers(props.group.id)
+ // 仅显示已设置 rate_multiplier 的条目;rpm_override 在另一个弹窗管理,保留不动
+ serverEntries.value = raw.filter(e => e.rate_multiplier != null)
localEntries.value = cloneEntries(serverEntries.value)
adjustPage()
} catch (error) {
@@ -394,7 +395,8 @@ const handleAddLocal = () => {
user_email: user.email,
user_notes: user.notes || '',
user_status: user.status || 'active',
- rate_multiplier: newRate.value
+ rate_multiplier: newRate.value,
+ rpm_override: null
}
if (idx >= 0) {
localEntries.value[idx] = entry
@@ -409,12 +411,15 @@ const handleAddLocal = () => {
// 本地修改倍率
const updateLocalRate = (userId: number, value: string) => {
+ const entry = localEntries.value.find(e => e.user_id === userId)
+ if (!entry) return
+ if (value.trim() === '') {
+ entry.rate_multiplier = null
+ return
+ }
const num = parseFloat(value)
if (isNaN(num)) return
- const entry = localEntries.value.find(e => e.user_id === userId)
- if (entry) {
- entry.rate_multiplier = num
- }
+ entry.rate_multiplier = num
}
// 本地删除
@@ -427,7 +432,9 @@ const removeLocal = (userId: number) => {
const applyBatchFactor = () => {
if (!batchFactor.value || batchFactor.value <= 0) return
for (const entry of localEntries.value) {
- entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6))
+ if (entry.rate_multiplier != null) {
+ entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6))
+ }
}
batchFactor.value = null
}
@@ -444,15 +451,17 @@ const handleCancel = () => {
adjustPage()
}
-// 保存:一次性提交所有数据
+// 保存:一次性提交所有数据(只提交 rate_multiplier;rpm_override 由独立弹窗管理)
const handleSave = async () => {
if (!props.group) return
saving.value = true
try {
- const entries = localEntries.value.map(e => ({
- user_id: e.user_id,
- rate_multiplier: e.rate_multiplier
- }))
+ const entries = localEntries.value
+ .filter(e => e.rate_multiplier != null)
+ .map(e => ({
+ user_id: e.user_id,
+ rate_multiplier: e.rate_multiplier as number
+ }))
await adminAPI.groups.batchSetGroupRateMultipliers(props.group.id, entries)
appStore.showSuccess(t('admin.groups.rateSaved'))
emit('success')
diff --git a/frontend/src/components/admin/monitor/MonitorActionsCell.vue b/frontend/src/components/admin/monitor/MonitorActionsCell.vue
new file mode 100644
index 00000000..74aa4017
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorActionsCell.vue
@@ -0,0 +1,45 @@
+
+
+
+
+ {{ t('admin.channelMonitor.runNow') }}
+
+
+
+ {{ t('common.edit') }}
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
new file mode 100644
index 00000000..0d6b4ace
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
@@ -0,0 +1,301 @@
+
+
+
+
+
{{ t('admin.channelMonitor.advanced.headers') }}
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.advanced.headerAddRow') }}
+
+
+
{{ headersError }}
+
+ {{ t('admin.channelMonitor.advanced.headersHint') }}
+
+
+
+
+
+
{{ t('admin.channelMonitor.advanced.bodyMode') }}
+
+
+ {{ opt.label }}
+
+
+
+ {{ bodyModeHint }}
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.advanced.bodyJson') }}
+
+ {{ t('admin.channelMonitor.advanced.bodyJsonFormat') }}
+
+
+
+
{{ bodyError }}
+
+ {{ t('admin.channelMonitor.advanced.bodyJsonHint') }}
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorFiltersBar.vue b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
new file mode 100644
index 00000000..eb2a5c78
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue
@@ -0,0 +1,104 @@
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.manageButton') }}
+
+
+
+ {{ t('admin.channelMonitor.createButton') }}
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
new file mode 100644
index 00000000..21fa4715
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue
@@ -0,0 +1,459 @@
+
+
+
+
+
+
+
+ {{ t('common.cancel') }}
+
+
+ {{ submitting
+ ? t('common.submitting')
+ : editing ? t('common.update') : t('common.create') }}
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
new file mode 100644
index 00000000..8df8d586
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue
@@ -0,0 +1,119 @@
+
+
+
+
+ {{ t('admin.channelMonitor.form.selectKeyHint') }}
+
+
+
+
+
+ {{ t('common.loading') }}
+
+
+ {{ t('admin.channelMonitor.form.noActiveKey') }}
+
+
+
+
+
+ {{ t('common.name') }}
+ {{ t('keys.apiKey') }}
+ {{ t('keys.group') }}
+
+
+
+
+ {{ k.name }}
+ {{ maskApiKey(k.key) }}
+
+
+ —
+
+
+
+
+
+
+
+
+
+ {{ t('common.cancel') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue
new file mode 100644
index 00000000..eccec828
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue
@@ -0,0 +1,71 @@
+
+
+
{{ row.primary_model }}
+
+
+
+ {{ statusLabel(row.primary_status) }}
+
+
+
+
+ {{ row.primary_model }}
+
+ {{ statusLabel(row.primary_status) }}
+
+
+
+ {{ t('monitorCommon.extraModelsEmpty') }}
+
+
+
+ {{ t('monitorCommon.extraModelsHeader') }}
+
+
+
+
+ {{ t('admin.channelMonitor.columns.primaryModel') }}
+ {{ t('admin.channelMonitor.columns.actions') }}
+ {{ t('admin.channelMonitor.columns.latency') }}
+
+
+
+
+ {{ m.model }}
+
+
+ {{ statusLabel(m.status) }}
+
+
+ {{ formatLatency(m.latency_ms) }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue
new file mode 100644
index 00000000..02fa6e8d
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue
@@ -0,0 +1,56 @@
+
+
+
+
+
+ {{ r.model }}
+ {{ r.message }}
+
+
+
+ {{ statusLabel(r.status) }}
+
+ {{ formatLatency(r.latency_ms) }} ms
+
+
+
+
+
+
+ {{ t('common.close') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue
new file mode 100644
index 00000000..427b75ff
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorTemplateApplyPickerDialog.vue
@@ -0,0 +1,174 @@
+
+
+
+ {{ t('admin.channelMonitor.template.applyPickerHint') }}
+
+
+
+ {{ t('common.loading') }}
+
+
+
+ {{ t('admin.channelMonitor.template.applyPickerEmpty') }}
+
+
+
+
+
+
+ {{ t('common.selectAll') }}
+
+
+ {{ t('admin.channelMonitor.template.selectNone') }}
+
+
+ {{ t('admin.channelMonitor.template.selectedCount', {
+ n: selectedIds.length,
+ total: monitors.length,
+ }) }}
+
+
+
+
+
+
+
+
+
+ {{ t('common.cancel') }}
+
+
+ {{ submitting
+ ? t('common.submitting')
+ : t('admin.channelMonitor.template.applyPickerConfirm', { n: selectedIds.length }) }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue b/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
new file mode 100644
index 00000000..3a03f5bc
--- /dev/null
+++ b/frontend/src/components/admin/monitor/MonitorTemplateManagerDialog.vue
@@ -0,0 +1,447 @@
+
+
+
+
+
+
+ {{ tab.label }}
+
+ {{ countByProvider[tab.value] }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.createButton') }}
+
+
+
+
+ {{ t('common.loading') }}
+
+
+
+ {{ t('admin.channelMonitor.template.emptyState') }}
+
+
+
+
+
+
+ {{ tpl.name }}
+
+ {{ modeLabel(tpl.body_override_mode) }}
+
+
+ {{ t('admin.channelMonitor.template.associatedCount', { n: tpl.associated_monitors }) }}
+
+
+
+ {{ tpl.description }}
+
+
+ {{ t('admin.channelMonitor.template.headersSummary', {
+ n: Object.keys(tpl.extra_headers || {}).length,
+ }) }}
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.applyButton') }}
+
+
+ {{ t('common.edit') }}
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.form.name') }}
+ *
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.form.provider') }}
+ *
+
+
+
+ {{ opt.label }}
+
+
+
+
+
+
+ {{ t('admin.channelMonitor.template.form.description') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('common.back') }}
+
+
+
+
+
+ {{ t('common.close') }}
+
+
+ {{ submitting ? t('common.submitting') : editing === 'new' ? t('common.create') : t('common.update') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/admin/user/UserAllowedGroupsModal.vue b/frontend/src/components/admin/user/UserAllowedGroupsModal.vue
index bccc22c7..ac822d41 100644
--- a/frontend/src/components/admin/user/UserAllowedGroupsModal.vue
+++ b/frontend/src/components/admin/user/UserAllowedGroupsModal.vue
@@ -81,7 +81,7 @@
+
+
{{ t('admin.users.form.rpmLimit') }}
+
+
{{ t('admin.users.form.rpmLimitHint') }}
+
@@ -57,7 +69,7 @@ import Icon from '@/components/icons/Icon.vue'
const props = defineProps<{ show: boolean }>()
const emit = defineEmits(['close', 'success']); const { t } = useI18n()
-const form = reactive({ email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1 })
+const form = reactive({ email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1, rpm_limit: 0 })
const { loading, submit } = useForm({
form,
@@ -68,7 +80,7 @@ const { loading, submit } = useForm({
successMsg: t('admin.users.userCreated')
})
-watch(() => props.show, (v) => { if(v) Object.assign(form, { email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1 }) })
+watch(() => props.show, (v) => { if(v) Object.assign(form, { email: '', password: '', username: '', notes: '', balance: 0, concurrency: 1, rpm_limit: 0 }) })
const generateRandomPassword = () => {
const chars = 'ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz23456789!@#$%^&*'
diff --git a/frontend/src/components/admin/user/UserEditModal.vue b/frontend/src/components/admin/user/UserEditModal.vue
index 70ebd2d3..de234986 100644
--- a/frontend/src/components/admin/user/UserEditModal.vue
+++ b/frontend/src/components/admin/user/UserEditModal.vue
@@ -37,6 +37,18 @@
{{ t('admin.users.columns.concurrency') }}
+
+
{{ t('admin.users.form.rpmLimit') }}
+
+
{{ t('admin.users.form.rpmLimitHint') }}
+
@@ -66,11 +78,11 @@ const emit = defineEmits(['close', 'success'])
const { t } = useI18n(); const appStore = useAppStore(); const { copyToClipboard } = useClipboard()
const submitting = ref(false); const passwordCopied = ref(false)
-const form = reactive({ email: '', password: '', username: '', notes: '', concurrency: 1, customAttributes: {} as UserAttributeValuesMap })
+const form = reactive({ email: '', password: '', username: '', notes: '', concurrency: 1, rpm_limit: 0, customAttributes: {} as UserAttributeValuesMap })
watch(() => props.user, (u) => {
if (u) {
- Object.assign(form, { email: u.email, password: '', username: u.username || '', notes: u.notes || '', concurrency: u.concurrency, customAttributes: {} })
+ Object.assign(form, { email: u.email, password: '', username: u.username || '', notes: u.notes || '', concurrency: u.concurrency, rpm_limit: u.rpm_limit ?? 0, customAttributes: {} })
passwordCopied.value = false
}
}, { immediate: true })
@@ -97,7 +109,7 @@ const handleUpdateUser = async () => {
}
submitting.value = true
try {
- const data: any = { email: form.email, username: form.username, notes: form.notes, concurrency: form.concurrency }
+ const data: any = { email: form.email, username: form.username, notes: form.notes, concurrency: form.concurrency, rpm_limit: form.rpm_limit }
if (form.password.trim()) data.password = form.password.trim()
await adminAPI.users.update(props.user.id, data)
if (Object.keys(form.customAttributes).length > 0) await adminAPI.userAttributes.updateUserAttributeValues(props.user.id, form.customAttributes)
diff --git a/frontend/src/components/auth/LinuxDoOAuthSection.vue b/frontend/src/components/auth/LinuxDoOAuthSection.vue
index c740d06f..6b245123 100644
--- a/frontend/src/components/auth/LinuxDoOAuthSection.vue
+++ b/frontend/src/components/auth/LinuxDoOAuthSection.vue
@@ -42,9 +42,11 @@
+
+
diff --git a/frontend/src/components/auth/TotpLoginModal.vue b/frontend/src/components/auth/TotpLoginModal.vue
index 03fa718d..0ae2f482 100644
--- a/frontend/src/components/auth/TotpLoginModal.vue
+++ b/frontend/src/components/auth/TotpLoginModal.vue
@@ -47,11 +47,6 @@
-
-
- {{ error }}
-
-
import { ref, watch, nextTick, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
+import { useAppStore } from '@/stores'
defineProps<{
tempToken: string
@@ -81,9 +77,9 @@ const emit = defineEmits<{
}>()
const { t } = useI18n()
+const appStore = useAppStore()
const verifying = ref(false)
-const error = ref('')
const code = ref(['', '', '', '', '', ''])
const inputRefs = ref<(HTMLInputElement | null)[]>([])
@@ -100,7 +96,9 @@ watch(
defineExpose({
setVerifying: (value: boolean) => { verifying.value = value },
setError: (message: string) => {
- error.value = message
+ if (message) {
+ appStore.showError(message)
+ }
code.value = ['', '', '', '', '', '']
// Clear input DOM values
inputRefs.value.forEach(input => {
diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue
new file mode 100644
index 00000000..c1b5be2e
--- /dev/null
+++ b/frontend/src/components/auth/WechatOAuthSection.vue
@@ -0,0 +1,96 @@
+
+
+
+
+ W
+
+ {{ t('auth.oidc.signIn', { providerName }) }}
+
+
+
+ {{ disabledHint }}
+
+
+
+
+
+ {{ t('auth.oauthOrContinue') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
new file mode 100644
index 00000000..1e462e29
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/PendingOAuthCreateAccountForm.spec.ts
@@ -0,0 +1,205 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+import PendingOAuthCreateAccountForm from '../PendingOAuthCreateAccountForm.vue'
+
+const sendVerifyCode = vi.fn()
+const sendPendingOAuthVerifyCode = vi.fn()
+const getPublicSettings = vi.fn()
+const showError = vi.fn()
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key
+ })
+ }
+})
+
+vi.mock('@/api/auth', async () => {
+ const actual = await vi.importActual('@/api/auth')
+ return {
+ ...actual,
+ sendVerifyCode: (...args: any[]) => sendVerifyCode(...args),
+ sendPendingOAuthVerifyCode: (...args: any[]) => sendPendingOAuthVerifyCode(...args),
+ getPublicSettings: (...args: any[]) => getPublicSettings(...args)
+ }
+})
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError
+ })
+}))
+
+describe('PendingOAuthCreateAccountForm', () => {
+ beforeEach(() => {
+ sendVerifyCode.mockReset()
+ sendPendingOAuthVerifyCode.mockReset()
+ getPublicSettings.mockReset()
+ showError.mockReset()
+ getPublicSettings.mockResolvedValue({
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
+ })
+
+ it('emits trimmed email, password, and verify code on submit', async () => {
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: 'prefill@example.com',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue(' 246810 ')
+ await wrapper.get('form').trigger('submit.prevent')
+
+ expect(wrapper.emitted('submit')).toEqual([
+ [
+ {
+ email: 'user@example.com',
+ password: 'secret-123',
+ verifyCode: '246810'
+ }
+ ]
+ ])
+ })
+
+ it('renders action labels through i18n keys', () => {
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ expect(wrapper.text()).toContain('auth.createAccount')
+ expect(wrapper.text()).toContain('auth.alreadyHaveAccount')
+ })
+
+ it('shows and emits invitation code when invitation-only signup is enabled', async () => {
+ getPublicSettings.mockResolvedValue({
+ invitation_code_enabled: true,
+ turnstile_enabled: false,
+ turnstile_site_key: ''
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: 'prefill@example.com',
+ isSubmitting: false
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-password"]').setValue('secret-123')
+ await wrapper.get('[data-testid="linuxdo-create-account-verify-code"]').setValue('246810')
+ await wrapper.get('[data-testid="linuxdo-create-account-invitation-code"]').setValue(' INVITE123 ')
+ await wrapper.get('form').trigger('submit.prevent')
+
+ expect(wrapper.emitted('submit')).toEqual([
+ [
+ {
+ email: 'prefill@example.com',
+ password: 'secret-123',
+ verifyCode: '246810',
+ invitationCode: 'INVITE123'
+ }
+ ]
+ ])
+ })
+
+ it('sends a verify code for the trimmed email value', async () => {
+ sendPendingOAuthVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
+ email: 'user@example.com'
+ })
+ })
+
+ it('shows send-code failures via toast without rendering inline error text', async () => {
+ sendPendingOAuthVerifyCode.mockRejectedValue(new Error('send failed'))
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ }
+ })
+
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue('user@example.com')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(showError).toHaveBeenCalledWith('send failed')
+ expect(wrapper.text()).not.toContain('send failed')
+ })
+
+ it('requires a turnstile token before sending a verify code when turnstile is enabled', async () => {
+ getPublicSettings.mockResolvedValue({
+ turnstile_enabled: true,
+ turnstile_site_key: 'site-key'
+ })
+ sendPendingOAuthVerifyCode.mockResolvedValue({
+ message: 'sent',
+ countdown: 60
+ })
+
+ const wrapper = mount(PendingOAuthCreateAccountForm, {
+ props: {
+ providerName: 'LinuxDo',
+ testIdPrefix: 'linuxdo',
+ initialEmail: '',
+ isSubmitting: false
+ },
+ global: {
+ stubs: {
+ TurnstileWidget: {
+ template: 'verify '
+ }
+ }
+ }
+ })
+
+ await flushPromises()
+ await wrapper.get('[data-testid="linuxdo-create-account-email"]').setValue(' user@example.com ')
+
+ expect(wrapper.get('[data-testid="linuxdo-create-account-send-code"]').attributes('disabled')).toBeDefined()
+
+ await wrapper.get('[data-testid="turnstile-verify"]').trigger('click')
+ await wrapper.get('[data-testid="linuxdo-create-account-send-code"]').trigger('click')
+ await flushPromises()
+
+ expect(sendPendingOAuthVerifyCode).toHaveBeenCalledWith({
+ email: 'user@example.com',
+ turnstile_token: 'turnstile-token'
+ })
+ })
+})
diff --git a/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts b/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts
new file mode 100644
index 00000000..06fbe397
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/TotpLoginModal.spec.ts
@@ -0,0 +1,41 @@
+import { mount } from '@vue/test-utils'
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import TotpLoginModal from '@/components/auth/TotpLoginModal.vue'
+
+const { showErrorMock } = vi.hoisted(() => ({
+ showErrorMock: vi.fn(),
+}))
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError: (...args: any[]) => showErrorMock(...args),
+ }),
+}))
+
+describe('TotpLoginModal', () => {
+ beforeEach(() => {
+ showErrorMock.mockReset()
+ })
+
+ it('sends verification errors to toast and does not render inline red text', async () => {
+ const wrapper = mount(TotpLoginModal, {
+ props: {
+ tempToken: 'temp-token',
+ userEmailMasked: 'u***@example.com',
+ },
+ })
+
+ ;(wrapper.vm as unknown as { setError: (message: string) => void }).setError('Invalid code')
+ await wrapper.vm.$nextTick()
+
+ expect(showErrorMock).toHaveBeenCalledWith('Invalid code')
+ expect(wrapper.text()).not.toContain('Invalid code')
+ expect(wrapper.find('.bg-red-50').exists()).toBe(false)
+ })
+})
diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
new file mode 100644
index 00000000..2f269e0b
--- /dev/null
+++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts
@@ -0,0 +1,238 @@
+import { mount } from '@vue/test-utils'
+import { createPinia, setActivePinia } from 'pinia'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue'
+import { useAppStore } from '@/stores'
+import type { PublicSettings } from '@/types'
+
+const routeState = vi.hoisted(() => ({
+ query: {} as Record,
+}))
+
+const locationState = vi.hoisted(() => ({
+ current: { href: 'http://localhost/login' } as { href: string },
+}))
+
+let pinia: ReturnType
+
+vi.mock('vue-router', () => ({
+ useRoute: () => routeState,
+}))
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ locale: { value: 'en' },
+ t: (key: string, params?: Record) => {
+ if (key === 'auth.wechatProviderName') {
+ return 'Mock WeChat'
+ }
+ if (key === 'auth.oidc.signIn') {
+ return `Continue with ${params?.providerName ?? ''}`.trim()
+ }
+ if (key === 'auth.oauthFlow.wechatSystemBrowserOnly') {
+ return 'MOCK-SYSTEM-BROWSER-ONLY'
+ }
+ if (key === 'auth.oauthFlow.wechatBrowserOnly') {
+ return 'MOCK-WECHAT-BROWSER-ONLY'
+ }
+ if (key === 'auth.oauthFlow.wechatNotConfigured') {
+ return 'MOCK-NOT-CONFIGURED'
+ }
+ if (key === 'auth.oauthOrContinue') {
+ return 'or continue'
+ }
+ return key
+ },
+ }),
+ }
+})
+
+type WeChatPublicSettings = PublicSettings & {
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+}
+
+function buildPublicSettings(overrides: Partial = {}): WeChatPublicSettings {
+ return {
+ registration_enabled: true,
+ email_verify_enabled: false,
+ force_email_on_third_party_signup: false,
+ registration_email_suffix_whitelist: [],
+ promo_code_enabled: true,
+ password_reset_enabled: false,
+ invitation_code_enabled: false,
+ turnstile_enabled: false,
+ turnstile_site_key: '',
+ site_name: 'Sub2API',
+ site_logo: '',
+ site_subtitle: '',
+ api_base_url: '/api/v1',
+ contact_info: '',
+ doc_url: '',
+ home_content: '',
+ hide_ccs_import_button: false,
+ payment_enabled: false,
+ table_default_page_size: 20,
+ table_page_size_options: [10, 20, 50, 100],
+ custom_menu_items: [],
+ custom_endpoints: [],
+ linuxdo_oauth_enabled: false,
+ wechat_oauth_enabled: true,
+ oidc_oauth_enabled: false,
+ oidc_oauth_provider_name: 'OIDC',
+ backend_mode_enabled: false,
+ version: 'test',
+ balance_low_notify_enabled: false,
+ account_quota_notify_enabled: false,
+ balance_low_notify_threshold: 0,
+ ...overrides,
+ }
+}
+
+function seedPublicSettings(overrides: Partial = {}): void {
+ const appStore = useAppStore()
+ const settings = buildPublicSettings(overrides)
+ appStore.cachedPublicSettings = settings
+ appStore.publicSettingsLoaded = true
+}
+
+describe('WechatOAuthSection', () => {
+ beforeEach(() => {
+ pinia = createPinia()
+ setActivePinia(pinia)
+ routeState.query = { redirect: '/billing?plan=pro' }
+ locationState.current = { href: 'http://localhost/login' }
+ Object.defineProperty(window, 'location', {
+ configurable: true,
+ value: locationState.current,
+ })
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0',
+ })
+ })
+
+ afterEach(() => {
+ vi.unstubAllGlobals()
+ })
+
+ it('starts the open WeChat OAuth flow with the current redirect target when open mode is configured', async () => {
+ seedPublicSettings({
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.text()).toContain('Mock WeChat')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('uses mp mode inside the WeChat browser when mp mode is configured', async () => {
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0 MicroMessenger',
+ })
+ seedPublicSettings({
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: true,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('disables the button outside the WeChat browser when only mp mode is configured', async () => {
+ seedPublicSettings({
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: true,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.get('button').attributes('disabled')).toBeDefined()
+ expect(wrapper.text()).toContain('MOCK-WECHAT-BROWSER-ONLY')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toBe('http://localhost/login')
+ })
+
+ it('disables the button inside the WeChat browser when only open mode is configured', async () => {
+ Object.defineProperty(window.navigator, 'userAgent', {
+ configurable: true,
+ value: 'Mozilla/5.0 MicroMessenger',
+ })
+ seedPublicSettings({
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.get('button').attributes('disabled')).toBeDefined()
+ expect(wrapper.text()).toContain('MOCK-SYSTEM-BROWSER-ONLY')
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toBe('http://localhost/login')
+ })
+
+ it('uses the legacy overall enabled flag when per-mode settings are not present', async () => {
+ seedPublicSettings({
+ wechat_oauth_enabled: true,
+ })
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ await wrapper.get('button').trigger('click')
+
+ expect(locationState.current.href).toContain(
+ '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro'
+ )
+ })
+
+ it('shows the localized not-configured hint when WeChat OAuth is unavailable', async () => {
+ seedPublicSettings({
+ wechat_oauth_enabled: false,
+ wechat_oauth_open_enabled: false,
+ wechat_oauth_mp_enabled: false,
+ })
+
+ const wrapper = mount(WechatOAuthSection, {
+ global: {
+ plugins: [pinia],
+ },
+ })
+
+ expect(wrapper.text()).toContain('MOCK-NOT-CONFIGURED')
+ })
+})
diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue
new file mode 100644
index 00000000..5b9c0eba
--- /dev/null
+++ b/frontend/src/components/channels/AvailableChannelsTable.vue
@@ -0,0 +1,189 @@
+
+
+
+
+
+ {{ columns.name }}
+ {{ columns.description }}
+ {{ columns.platform }}
+ {{ columns.groups }}
+ {{ columns.supportedModels }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ emptyLabel }}
+
+
+
+
+
+
+
+
+ {{ channel.name }}
+
+
+
+
+ {{ channel.description }}
+ -
+
+
+
+
+
+
+ {{ section.platform }}
+
+
+
+
+
+
+
+
+
+ {{ t('availableChannels.exclusive') }}
+
+
+
+
+
+
+ {{ t('availableChannels.public') }}
+
+
+
+
-
+
+
+
+
+
+
+
+
+ {{ noModelsLabel }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/channels/PricingRow.vue b/frontend/src/components/channels/PricingRow.vue
new file mode 100644
index 00000000..4134593b
--- /dev/null
+++ b/frontend/src/components/channels/PricingRow.vue
@@ -0,0 +1,25 @@
+
+
+ {{ label }}
+ {{ display }}
+
+
+
+
diff --git a/frontend/src/components/channels/SupportedModelChip.vue b/frontend/src/components/channels/SupportedModelChip.vue
new file mode 100644
index 00000000..3fe32e2f
--- /dev/null
+++ b/frontend/src/components/channels/SupportedModelChip.vue
@@ -0,0 +1,301 @@
+
+
+
+
+
+ {{ model.platform }}
+
+ {{ model.name }}
+
+
+
+
+
+
+
+ {{ model.name }}
+
+ {{ model.platform }}
+
+
+
+
+
+ {{ noPricingLabel }}
+
+
+
+
+ {{ t(prefixKey('billingMode')) }}
+ {{ billingModeLabel }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t(prefixKey('intervals')) }}
+
+
+
+
+ {{ iv.tier_label }}
+ {{ formatRange(iv.min_tokens, iv.max_tokens) }}
+
+ {{ formatInterval(iv, model.pricing.billing_mode) }}
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/common/AutoRefreshButton.vue b/frontend/src/components/common/AutoRefreshButton.vue
new file mode 100644
index 00000000..797c8752
--- /dev/null
+++ b/frontend/src/components/common/AutoRefreshButton.vue
@@ -0,0 +1,82 @@
+
+
+
+
+
+
+
+ {{ enabled
+ ? t('common.autoRefresh.countdown', { seconds: countdown })
+ : t('common.autoRefresh.title')
+ }}
+
+
+
+
+
+
+ {{ t('common.autoRefresh.enable') }}
+
+
+
+
+
+
+ {{ t('common.autoRefresh.seconds', { n: sec }) }}
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/common/GroupBadge.vue b/frontend/src/components/common/GroupBadge.vue
index 83f4b8aa..3303d909 100644
--- a/frontend/src/components/common/GroupBadge.vue
+++ b/frontend/src/components/common/GroupBadge.vue
@@ -37,13 +37,20 @@ interface Props {
userRateMultiplier?: number | null // 用户专属倍率
showRate?: boolean
daysRemaining?: number | null // 剩余天数(订阅类型时使用)
+ /**
+ * 订阅分组默认在右侧 label 展示"订阅"或剩余天数;
+ * 开启后订阅分组也改为显示倍率(保留订阅主题色 label,配合可用渠道这类
+ * 只关心费率、不关心有效期的场景)。
+ */
+ alwaysShowRate?: boolean
}
const props = withDefaults(defineProps(), {
subscriptionType: 'standard',
showRate: true,
daysRemaining: null,
- userRateMultiplier: null
+ userRateMultiplier: null,
+ alwaysShowRate: false
})
const { t } = useI18n()
@@ -71,7 +78,8 @@ const showLabel = computed(() => {
// Label text
const labelText = computed(() => {
- if (isSubscription.value) {
+ const rateLabel = props.rateMultiplier !== undefined ? `${props.rateMultiplier}x` : ''
+ if (isSubscription.value && !props.alwaysShowRate) {
// 如果有剩余天数,显示天数
if (props.daysRemaining !== null && props.daysRemaining !== undefined) {
if (props.daysRemaining <= 0) {
@@ -82,7 +90,7 @@ const labelText = computed(() => {
// 否则显示"订阅"
return t('groups.subscription')
}
- return props.rateMultiplier !== undefined ? `${props.rateMultiplier}x` : ''
+ return rateLabel
})
// Label style based on type and days remaining
diff --git a/frontend/src/components/common/HelpTooltip.vue b/frontend/src/components/common/HelpTooltip.vue
index e95052da..d2a2e48f 100644
--- a/frontend/src/components/common/HelpTooltip.vue
+++ b/frontend/src/components/common/HelpTooltip.vue
@@ -1,23 +1,69 @@
@@ -35,6 +95,7 @@ function updatePosition() {
class="group relative ml-1 inline-flex items-center align-middle"
@mouseenter="onEnter"
@mouseleave="onLeave"
+ @click="onClick"
>
@@ -56,10 +117,26 @@ function updatePosition() {
diff --git a/frontend/src/components/common/Pagination.vue b/frontend/src/components/common/Pagination.vue
index 2bfc6872..9b4ac200 100644
--- a/frontend/src/components/common/Pagination.vue
+++ b/frontend/src/components/common/Pagination.vue
@@ -123,6 +123,7 @@ import { useI18n } from 'vue-i18n'
import Icon from '@/components/icons/Icon.vue'
import Select from './Select.vue'
import { getConfiguredTablePageSizeOptions, normalizeTablePageSize } from '@/utils/tablePreferences'
+import { setPersistedPageSize } from '@/composables/usePersistedPageSize'
const { t } = useI18n()
@@ -224,6 +225,7 @@ const goToPage = (newPage: number) => {
const handlePageSizeChange = (value: string | number | boolean | null) => {
if (value === null || typeof value === 'boolean') return
const newPageSize = normalizeTablePageSize(typeof value === 'string' ? parseInt(value, 10) : value)
+ setPersistedPageSize(newPageSize)
emit('update:pageSize', newPageSize)
}
diff --git a/frontend/src/components/common/PlatformTypeBadge.vue b/frontend/src/components/common/PlatformTypeBadge.vue
index 1ebc8892..1c7b08c0 100644
--- a/frontend/src/components/common/PlatformTypeBadge.vue
+++ b/frontend/src/components/common/PlatformTypeBadge.vue
@@ -25,6 +25,7 @@
+
{{ typeLabel }}
@@ -88,6 +89,8 @@ const typeLabel = computed(() => {
return 'Key'
case 'bedrock':
return 'AWS'
+ case 'service_account':
+ return 'Vertex'
default:
return props.type
}
diff --git a/frontend/src/components/common/__tests__/HelpTooltip.spec.ts b/frontend/src/components/common/__tests__/HelpTooltip.spec.ts
new file mode 100644
index 00000000..778aabd9
--- /dev/null
+++ b/frontend/src/components/common/__tests__/HelpTooltip.spec.ts
@@ -0,0 +1,80 @@
+import { afterEach, describe, expect, it } from 'vitest'
+import { mount } from '@vue/test-utils'
+import { nextTick } from 'vue'
+import HelpTooltip from '@/components/common/HelpTooltip.vue'
+
+function getTooltipElement(): HTMLDivElement {
+ const tooltip = document.body.querySelector('[role="tooltip"]')
+ if (!(tooltip instanceof HTMLDivElement)) {
+ throw new Error('tooltip element not found')
+ }
+ return tooltip
+}
+
+describe('HelpTooltip', () => {
+ afterEach(() => {
+ document.body.innerHTML = ''
+ })
+
+ it('keeps the existing hover interaction by default', async () => {
+ const wrapper = mount(HelpTooltip, {
+ attachTo: document.body,
+ props: {
+ content: 'hover details',
+ },
+ })
+
+ const trigger = wrapper.get('.group')
+ const tooltip = getTooltipElement()
+
+ expect(tooltip.style.display).toBe('none')
+
+ await trigger.trigger('mouseenter')
+ await nextTick()
+ expect(tooltip.style.display).not.toBe('none')
+
+ await trigger.trigger('mouseleave')
+ await nextTick()
+ expect(tooltip.style.display).toBe('none')
+
+ wrapper.unmount()
+ })
+
+ it('supports click-to-toggle details and closes on outside click', async () => {
+ const wrapper = mount(HelpTooltip, {
+ attachTo: document.body,
+ props: {
+ content: 'click details',
+ trigger: 'click',
+ },
+ })
+
+ const trigger = wrapper.get('.group')
+ const tooltip = getTooltipElement()
+
+ expect(tooltip.style.display).toBe('none')
+
+ await trigger.trigger('click')
+ await nextTick()
+ expect(tooltip.style.display).not.toBe('none')
+ expect(tooltip.textContent).toContain('click details')
+
+ const closeButton = tooltip.querySelector('button[aria-label="Close"]')
+ if (!(closeButton instanceof HTMLButtonElement)) {
+ throw new Error('close button not found')
+ }
+ closeButton.click()
+ await nextTick()
+ expect(tooltip.style.display).toBe('none')
+
+ await trigger.trigger('click')
+ await nextTick()
+ expect(tooltip.style.display).not.toBe('none')
+
+ document.body.dispatchEvent(new MouseEvent('click', { bubbles: true }))
+ await nextTick()
+ expect(tooltip.style.display).toBe('none')
+
+ wrapper.unmount()
+ })
+})
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index 7770e658..94010b62 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -617,66 +617,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
const openaiModels = {
- 'gpt-5-codex': {
- name: 'GPT-5 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex': {
- name: 'GPT-5.1 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex-max': {
- name: 'GPT-5.1 Codex Max',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
- 'gpt-5.1-codex-mini': {
- name: 'GPT-5.1 Codex Mini',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {}
- }
- },
'gpt-5.2': {
name: 'GPT-5.2',
limit: {
@@ -693,6 +633,22 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
+ 'gpt-5.5': {
+ name: 'GPT-5.5',
+ limit: {
+ context: 1050000,
+ output: 128000
+ },
+ options: {
+ store: false
+ },
+ variants: {
+ low: {},
+ medium: {},
+ high: {},
+ xhigh: {}
+ }
+ },
'gpt-5.4': {
name: 'GPT-5.4',
limit: {
@@ -725,22 +681,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
- 'gpt-5.4-nano': {
- name: 'GPT-5.4 Nano',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {},
- xhigh: {}
- }
- },
'gpt-5.3-codex-spark': {
name: 'GPT-5.3 Codex Spark',
limit: {
@@ -773,22 +713,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
xhigh: {}
}
},
- 'gpt-5.2-codex': {
- name: 'GPT-5.2 Codex',
- limit: {
- context: 400000,
- output: 128000
- },
- options: {
- store: false
- },
- variants: {
- low: {},
- medium: {},
- high: {},
- xhigh: {}
- }
- },
'codex-mini-latest': {
name: 'Codex Mini',
limit: {
diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
index 98b5dede..f7db586a 100644
--- a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
+++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts
@@ -17,7 +17,7 @@ vi.mock('@/composables/useClipboard', () => ({
import UseKeyModal from '../UseKeyModal.vue'
describe('UseKeyModal', () => {
- it('renders updated GPT-5.4 mini/nano names in OpenCode config', async () => {
+ it('renders GPT-5.4 mini entry in OpenCode config', async () => {
const wrapper = mount(UseKeyModal, {
props: {
show: true,
@@ -48,6 +48,6 @@ describe('UseKeyModal', () => {
const codeBlock = wrapper.find('pre code')
expect(codeBlock.exists()).toBe(true)
expect(codeBlock.text()).toContain('"name": "GPT-5.4 Mini"')
- expect(codeBlock.text()).toContain('"name": "GPT-5.4 Nano"')
+ expect(codeBlock.text()).not.toContain('"name": "GPT-5.4 Nano"')
})
})
diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue
index fbcab521..306f1429 100644
--- a/frontend/src/components/layout/AppHeader.vue
+++ b/frontend/src/components/layout/AppHeader.vue
@@ -74,10 +74,14 @@
class="flex items-center gap-2 rounded-xl p-1.5 transition-colors hover:bg-gray-100 dark:hover:bg-dark-800"
aria-label="User Menu"
>
-
- {{ userInitials }}
+
+
+
{{ userInitials }}
@@ -232,6 +236,7 @@ const dropdownOpen = ref(false)
const dropdownRef = ref
(null)
const contactInfo = computed(() => appStore.contactInfo)
const docUrl = computed(() => appStore.docUrl)
+const avatarUrl = computed(() => user.value?.avatar_url?.trim() || '')
// 只在标准模式的管理员下显示新手引导按钮
const showOnboardingButton = computed(() => {
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index 92dcc519..d8e2794e 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -38,7 +38,7 @@
'sidebar-link-collapsed': sidebarCollapsed
}"
:title="sidebarCollapsed ? item.label : undefined"
- @click="sidebarCollapsed ? undefined : toggleGroup(item)"
+ @click="handleGroupClick(item)"
>
import { computed, h, onMounted, ref, watch } from 'vue'
-import { useRoute } from 'vue-router'
+import { useRoute, useRouter } from 'vue-router'
import { useI18n } from 'vue-i18n'
import { useAdminSettingsStore, useAppStore, useAuthStore, useOnboardingStore } from '@/stores'
import VersionBadge from '@/components/common/VersionBadge.vue'
import { sanitizeSvg } from '@/utils/sanitize'
+import { FeatureFlags, makeSidebarFlag } from '@/utils/featureFlags'
interface NavItem {
path: string
@@ -194,11 +195,39 @@ interface NavItem {
iconSvg?: string
hideInSimpleMode?: boolean
children?: NavItem[]
+ /**
+ * When true, the parent item only toggles the expand/collapse state and
+ * does NOT navigate to its `path`. The `path` is purely a stable key.
+ */
+ expandOnly?: boolean
+ /**
+ * 可选的功能开关 getter。返回 false 时菜单项被隐藏;返回 undefined/true 时显示。
+ * 宽容策略(undefined → 显示)避免 public settings 未加载完成时菜单闪烁消失。
+ * Getter 里访问的 reactive 来源(store / composable)会被 computed 自动追踪,
+ * 开关切换时菜单自动更新。
+ */
+ featureFlag?: () => boolean | undefined
+}
+
+// applyFeatureFlags 递归过滤掉 featureFlag() === false 的节点(含子节点)。
+// 使用 `!== false` 宽容语义:undefined(设置未加载)或 true 都视为显示。
+function applyFeatureFlags(items: NavItem[]): NavItem[] {
+ const out: NavItem[] = []
+ for (const item of items) {
+ if (item.featureFlag && item.featureFlag() === false) continue
+ if (item.children) {
+ out.push({ ...item, children: applyFeatureFlags(item.children) })
+ } else {
+ out.push(item)
+ }
+ }
+ return out
}
const { t } = useI18n()
const route = useRoute()
+const router = useRouter()
const appStore = useAppStore()
const authStore = useAuthStore()
const onboardingStore = useOnboardingStore()
@@ -549,6 +578,41 @@ const ChevronDoubleRightIcon = {
)
}
+const SignalIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M9.348 14.651a3.75 3.75 0 010-5.303m5.304 0a3.75 3.75 0 010 5.303m-7.425 2.122a6.75 6.75 0 010-9.546m9.546 0a6.75 6.75 0 010 9.546M5.106 18.894c-3.808-3.807-3.808-9.98 0-13.788m13.788 0c3.808 3.807 3.808 9.98 0 13.788M12 12h.008v.008H12V12zm.375 0a.375.375 0 11-.75 0 .375.375 0 01.75 0z'
+ })
+ ]
+ )
+}
+
+const PriceTagIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M9.568 3H5.25A2.25 2.25 0 003 5.25v4.318c0 .597.237 1.17.659 1.591l9.581 9.581c.699.699 1.78.872 2.607.33a18.095 18.095 0 005.223-5.223c.542-.827.369-1.908-.33-2.607L11.16 3.66A2.25 2.25 0 009.568 3z'
+ }),
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M6 6h.008v.008H6V6z'
+ })
+ ]
+ )
+}
+
const ChevronDownIcon = {
render: () =>
h(
@@ -564,72 +628,36 @@ const ChevronDownIcon = {
)
}
-// User navigation items (for regular users)
-const userNavItems = computed((): NavItem[] => {
- const items: NavItem[] = [
- { path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon },
- { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
- { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
- { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/purchase',
- label: t('nav.buySubscription'),
- icon: RechargeSubscriptionIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/orders',
- label: t('nav.myOrders'),
- icon: OrderListIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
- { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
- { path: '/profile', label: t('nav.profile'), icon: UserIcon },
- ...customMenuItemsForUser.value.map((item): NavItem => ({
- path: `/custom/${item.id}`,
- label: item.label,
- icon: null,
- iconSvg: item.icon_svg,
- })),
- ]
- return authStore.isSimpleMode ? items.filter(item => !item.hideInSimpleMode) : items
-})
+// Public-settings flags go through the registry in utils/featureFlags.ts,
+// which handles the opt-in vs opt-out fallback when settings haven't loaded
+// yet. Admin-only flags (not in public settings) stay inline below.
+const flagChannelMonitor = makeSidebarFlag(FeatureFlags.channelMonitor)
+const flagPayment = makeSidebarFlag(FeatureFlags.payment)
+const flagAvailableChannels = makeSidebarFlag(FeatureFlags.availableChannels)
+const flagAffiliate = makeSidebarFlag(FeatureFlags.affiliate)
+const flagOpsMonitoring = () => adminSettingsStore.opsMonitoringEnabled
+const flagAdminPayment = () => adminSettingsStore.paymentEnabled
-// Personal navigation items (for admin's "My Account" section, without Dashboard)
-const personalNavItems = computed((): NavItem[] => {
- const items: NavItem[] = [
+// buildSelfNavItems 构造用户自己的导航项(用户端主菜单和管理员的"我的账户"子菜单共享这组声明)。
+// withDashboard=true 时包含仪表盘(用户端),false 时不含(管理员的个人区已经有独立仪表盘入口)。
+//
+// 条目顺序:密钥 → 用量 → 可用渠道 → 渠道状态 → 订阅/支付 → 兑换/资料。
+// 可用渠道紧挨渠道状态之上,让用户"先看自己能用什么、再看对应状态"。
+function buildSelfNavItems(withDashboard: boolean): NavItem[] {
+ const items: NavItem[] = []
+ if (withDashboard) {
+ items.push({ path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon })
+ }
+ items.push(
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
+ { path: '/available-channels', label: t('nav.availableChannels'), icon: ChannelIcon, hideInSimpleMode: true, featureFlag: flagAvailableChannels },
+ { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon, featureFlag: flagChannelMonitor },
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/purchase',
- label: t('nav.buySubscription'),
- icon: RechargeSubscriptionIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
- ...(appStore.cachedPublicSettings?.payment_enabled
- ? [
- {
- path: '/orders',
- label: t('nav.myOrders'),
- icon: OrderListIcon,
- hideInSimpleMode: true
- },
- ]
- : []),
+ { path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment },
+ { path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment },
{ path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true },
+ { path: '/affiliate', label: t('nav.affiliate'), icon: UsersIcon, hideInSimpleMode: true, featureFlag: flagAffiliate },
{ path: '/profile', label: t('nav.profile'), icon: UserIcon },
...customMenuItemsForUser.value.map((item): NavItem => ({
path: `/custom/${item.id}`,
@@ -637,9 +665,23 @@ const personalNavItems = computed((): NavItem[] => {
icon: null,
iconSvg: item.icon_svg,
})),
- ]
- return authStore.isSimpleMode ? items.filter(item => !item.hideInSimpleMode) : items
-})
+ )
+ return items
+}
+
+// finalizeNav 合并三重过滤:featureFlag 过滤 + simple 模式过滤。
+function finalizeNav(items: NavItem[]): NavItem[] {
+ const visible = applyFeatureFlags(items)
+ return authStore.isSimpleMode ? visible.filter(item => !item.hideInSimpleMode) : visible
+}
+
+// User navigation items (for regular users)
+const userNavItems = computed((): NavItem[] => finalizeNav(buildSelfNavItems(true)))
+
+// Personal navigation items (for admin's "My Account" section, without Dashboard).
+// Admins access 可用渠道 from this section just like regular users — there is no
+// separate admin entry, since the page is purely a user-facing view.
+const personalNavItems = computed((): NavItem[] => finalizeNav(buildSelfNavItems(false)))
// Custom menu items filtered by visibility
const customMenuItemsForUser = computed(() => {
@@ -659,54 +701,60 @@ const customMenuItemsForAdmin = computed(() => {
const adminNavItems = computed((): NavItem[] => {
const baseItems: NavItem[] = [
{ path: '/admin/dashboard', label: t('nav.dashboard'), icon: DashboardIcon },
- ...(adminSettingsStore.opsMonitoringEnabled
- ? [{ path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon }]
- : []),
+ { path: '/admin/ops', label: t('nav.ops'), icon: ChartIcon, featureFlag: flagOpsMonitoring },
{ path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true },
{ path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true },
- { path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true },
+ {
+ path: '/admin/channels',
+ label: t('nav.channelManagement'),
+ icon: ChannelIcon,
+ hideInSimpleMode: true,
+ expandOnly: true,
+ children: [
+ { path: '/admin/channels/pricing', label: t('nav.channelPricing'), icon: PriceTagIcon },
+ { path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon, featureFlag: flagChannelMonitor },
+ ],
+ },
{ path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
{ path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon },
{ path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon },
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
- ...(adminSettingsStore.paymentEnabled
- ? [
- {
- path: '/admin/orders',
- label: t('nav.orderManagement'),
- icon: OrderIcon,
- hideInSimpleMode: true,
- children: [
- { path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon },
- { path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon },
- { path: '/admin/orders/plans', label: t('nav.paymentPlans'), icon: CreditCardIcon },
- ],
- },
- ]
- : []),
+ {
+ path: '/admin/orders',
+ label: t('nav.orderManagement'),
+ icon: OrderIcon,
+ hideInSimpleMode: true,
+ expandOnly: true,
+ featureFlag: flagAdminPayment,
+ children: [
+ { path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon },
+ { path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon },
+ { path: '/admin/orders/plans', label: t('nav.paymentPlans'), icon: CreditCardIcon },
+ ],
+ },
{ path: '/admin/usage', label: t('nav.usage'), icon: ChartIcon }
]
+ const visible = applyFeatureFlags(baseItems)
+
// 简单模式下,在系统设置前插入 API密钥
if (authStore.isSimpleMode) {
- const filtered = baseItems.filter(item => !item.hideInSimpleMode)
+ const filtered = visible.filter(item => !item.hideInSimpleMode)
filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon })
filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
- // Add admin custom menu items after settings
for (const cm of customMenuItemsForAdmin.value) {
filtered.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg })
}
return filtered
}
- baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
- // Add admin custom menu items after settings
+ visible.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
for (const cm of customMenuItemsForAdmin.value) {
- baseItems.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg })
+ visible.push({ path: `/custom/${cm.id}`, label: cm.label, icon: null, iconSvg: cm.icon_svg })
}
- return baseItems
+ return visible
})
function toggleSidebar() {
@@ -764,6 +812,28 @@ function toggleGroup(item: NavItem) {
}
}
+/**
+ * Click handler for collapsible parent items.
+ * - When sidebar is collapsed: do nothing (children are not visible).
+ * - When `expandOnly` is true: only toggle expand state.
+ * - Otherwise (default, e.g. /admin/orders): navigate to the parent path
+ * (router-link semantics) and ensure the group is expanded.
+ */
+function handleGroupClick(item: NavItem) {
+ if (sidebarCollapsed.value) return
+ if (item.expandOnly) {
+ toggleGroup(item)
+ return
+ }
+ // Push to path and ensure expanded
+ if (route.path !== item.path) {
+ router.push(item.path)
+ }
+ if (!expandedGroups.value.has(item.path)) {
+ expandedGroups.value.add(item.path)
+ }
+}
+
// Initialize theme
const savedTheme = localStorage.getItem('theme')
if (
diff --git a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
index 118c7615..592ce8a3 100644
--- a/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
+++ b/frontend/src/components/layout/__tests__/AppSidebar.spec.ts
@@ -21,7 +21,7 @@ describe('AppSidebar custom SVG styles', () => {
describe('AppSidebar header styles', () => {
it('does not clip the version badge dropdown', () => {
- const sidebarHeaderBlockMatch = styleSource.match(/\.sidebar-header\s*\{[\s\S]*?\n \}/)
+ const sidebarHeaderBlockMatch = styleSource.match(/\.sidebar-header\s*\{[\s\S]*?\n {2}\}/)
const sidebarBrandBlockMatch = componentSource.match(/\.sidebar-brand\s*\{[\s\S]*?\n\}/)
expect(sidebarHeaderBlockMatch).not.toBeNull()
diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue
index 10c1bfea..0e03ec60 100644
--- a/frontend/src/components/payment/PaymentProviderDialog.vue
+++ b/frontend/src/components/payment/PaymentProviderDialog.vue
@@ -73,9 +73,42 @@
-
- {{ t('admin.settings.payment.providerConfig') }}
-
+
+
+ {{ t('admin.settings.payment.providerConfig') }}
+
+
+
+
+ ?
+
+
+
+
{{ paymentGuide.summary }}
+
+
{{ item.title }}
+
{{ t('admin.settings.payment.guideOpenLabel') }} {{ item.open }}
+
{{ t('admin.settings.payment.guideCallLabel') }} {{ item.call }}
+
{{ t('admin.settings.payment.guideFallbackLabel') }} {{ item.fallback }}
+
+
+ {{ paymentGuide.note }}
+
+
+
+
+
+ {{ paymentGuide.summary }}
+
@@ -88,13 +121,24 @@
v-model="config[field.key]"
rows="3"
class="input font-mono text-xs"
+ autocomplete="new-password"
+ data-1p-ignore
+ data-lpignore="true"
+ data-bwignore="true"
+ spellcheck="false"
+ :placeholder="editing ? t('admin.accounts.leaveEmptyToKeep') : ''"
/>
{
}))
})
+const paymentGuide = computed(() => {
+ if (form.provider_key === 'alipay') {
+ return {
+ summary: t('admin.settings.payment.alipayGuideSummary'),
+ items: [
+ {
+ title: t('admin.settings.payment.alipayGuideFaceToFaceTitle'),
+ open: t('admin.settings.payment.alipayGuideFaceToFaceOpen'),
+ call: t('admin.settings.payment.alipayGuideFaceToFaceCall'),
+ fallback: t('admin.settings.payment.alipayGuideFaceToFaceFallback'),
+ },
+ {
+ title: t('admin.settings.payment.alipayGuidePagePayTitle'),
+ open: t('admin.settings.payment.alipayGuidePagePayOpen'),
+ call: t('admin.settings.payment.alipayGuidePagePayCall'),
+ fallback: t('admin.settings.payment.alipayGuidePagePayFallback'),
+ },
+ {
+ title: t('admin.settings.payment.alipayGuideWapTitle'),
+ open: t('admin.settings.payment.alipayGuideWapOpen'),
+ call: t('admin.settings.payment.alipayGuideWapCall'),
+ fallback: t('admin.settings.payment.alipayGuideWapFallback'),
+ },
+ ],
+ }
+ }
+
+ if (form.provider_key === 'wxpay') {
+ return {
+ summary: t('admin.settings.payment.wxpayGuideSummary'),
+ note: t('admin.settings.payment.wxpayGuideNote'),
+ items: [
+ {
+ title: t('admin.settings.payment.wxpayGuideNativeTitle'),
+ open: t('admin.settings.payment.wxpayGuideNativeOpen'),
+ call: t('admin.settings.payment.wxpayGuideNativeCall'),
+ fallback: t('admin.settings.payment.wxpayGuideNativeFallback'),
+ },
+ {
+ title: t('admin.settings.payment.wxpayGuideJsapiTitle'),
+ open: t('admin.settings.payment.wxpayGuideJsapiOpen'),
+ call: t('admin.settings.payment.wxpayGuideJsapiCall'),
+ fallback: t('admin.settings.payment.wxpayGuideJsapiFallback'),
+ },
+ {
+ title: t('admin.settings.payment.wxpayGuideH5Title'),
+ open: t('admin.settings.payment.wxpayGuideH5Open'),
+ call: t('admin.settings.payment.wxpayGuideH5Call'),
+ fallback: t('admin.settings.payment.wxpayGuideH5Fallback'),
+ },
+ ],
+ }
+ }
+
+ return null
+})
+
const limitableTypes = computed(() => {
// Stripe: single "stripe" entry (one set of shared limits)
if (form.provider_key === 'stripe') {
@@ -398,9 +513,12 @@ function handleSave() {
emitValidationError(t('admin.settings.payment.validationNameRequired'))
return
}
- // Validate required config fields — all non-optional fields must be filled
+ // Validate required config fields — all non-optional fields must be filled.
+ // In edit mode, sensitive fields may be left blank to preserve the stored
+ // value (backend merges blanks by preserving the existing secret).
for (const f of PROVIDER_CONFIG_FIELDS[form.provider_key] || []) {
if (f.optional) continue
+ if (props.editing && f.sensitive) continue
const val = (config[f.key] || '').trim()
if (!val) {
const label = f.label || t(`admin.settings.payment.field_${f.key}`)
@@ -412,8 +530,6 @@ function handleSave() {
const filteredConfig: Record = {}
for (const [k, v] of Object.entries(config)) {
if (!v || !v.trim()) continue
- // Skip masked values — backend keeps existing credentials
- if (v === '••••••••') continue
filteredConfig[k] = v
}
@@ -470,7 +586,8 @@ function loadProvider(provider: ProviderInstance) {
form.refund_enabled = provider.refund_enabled
form.allow_user_refund = provider.allow_user_refund
clearConfig()
- // Pre-fill config from API response (non-sensitive in cleartext, sensitive masked as ••••••••)
+ // Pre-fill config from API response. Backend omits sensitive fields entirely,
+ // so those inputs stay blank — submitting blank preserves the stored secret.
if (provider.config) {
for (const [k, v] of Object.entries(provider.config)) {
// Skip notifyUrl/returnUrl — they are derived from callbackBaseUrl
diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue
index b9026e78..09d273cc 100644
--- a/frontend/src/components/payment/PaymentQRDialog.vue
+++ b/frontend/src/components/payment/PaymentQRDialog.vue
@@ -78,8 +78,8 @@ import Icon from '@/components/icons/Icon.vue'
import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { extractI18nErrorMessage } from '@/utils/apiError'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import QRCode from 'qrcode'
import alipayIcon from '@/assets/icons/alipay.svg'
@@ -147,7 +147,7 @@ function getLogoForType(): string | null {
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
}
}
@@ -222,7 +222,7 @@ async function handleCancel() {
cleanup()
emit('close')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 974dee66..2cdc9dce 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -84,6 +84,9 @@
{{ scanHint }}
+
+ {{ t('payment.qr.openPayWindow') }}
+
@@ -124,8 +127,8 @@ import { useI18n } from 'vue-i18n'
import { usePaymentStore } from '@/stores/payment'
import { useAppStore } from '@/stores'
import { paymentAPI } from '@/api/payment'
-import { extractApiErrorMessage } from '@/utils/apiError'
-import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { extractI18nErrorMessage } from '@/utils/apiError'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { PaymentOrder } from '@/types/payment'
import Icon from '@/components/icons/Icon.vue'
import QRCode from 'qrcode'
@@ -141,7 +144,9 @@ const props = defineProps<{
orderType?: string
}>()
-const emit = defineEmits<{ done: []; success: [] }>()
+type PaymentOutcome = 'success' | 'cancelled' | 'expired'
+
+const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>()
const { t } = useI18n()
const paymentStore = usePaymentStore()
@@ -154,7 +159,7 @@ const cancelling = ref(false)
const paidOrder = ref
(null)
// Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired'
-const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null)
+const outcome = ref(null)
let pollTimer: ReturnType | null = null
let countdownTimer: ReturnType | null = null
@@ -192,12 +197,25 @@ const countdownDisplay = computed(() => {
return m.toString().padStart(2, '0') + ':' + s.toString().padStart(2, '0')
})
+function isSuccessStatus(status: string | null | undefined): boolean {
+ return status === 'COMPLETED' || status === 'PAID' || status === 'RECHARGING'
+}
+
function reopenPopup() {
if (props.payUrl) {
- window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES)
+ const win = window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures())
+ if (!win || win.closed) {
+ window.location.href = props.payUrl
+ }
}
}
+function setOutcome(next: PaymentOutcome) {
+ if (outcome.value === next) return
+ outcome.value = next
+ emit('settled', next)
+}
+
async function renderQR() {
await nextTick()
if (!qrCanvas.value || !qrUrl.value) return
@@ -211,26 +229,26 @@ async function pollStatus() {
if (!props.orderId || outcome.value) return
const order = await paymentStore.pollOrderStatus(props.orderId)
if (!order) return
- if (order.status === 'COMPLETED' || order.status === 'PAID') {
+ if (isSuccessStatus(order.status)) {
cleanup()
paidOrder.value = order
- outcome.value = 'success'
+ setOutcome('success')
emit('success')
} else if (order.status === 'CANCELLED') {
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} else if (order.status === 'EXPIRED' || order.status === 'FAILED') {
cleanup()
- outcome.value = 'expired'
+ setOutcome('expired')
}
}
function startCountdown(seconds: number) {
remainingSeconds.value = Math.max(0, seconds)
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); return }
countdownTimer = setInterval(() => {
remainingSeconds.value--
- if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() }
+ if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() }
}, 1000)
}
@@ -240,9 +258,9 @@ async function handleCancel() {
try {
await paymentAPI.cancelOrder(props.orderId)
cleanup()
- outcome.value = 'cancelled'
+ setOutcome('cancelled')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue
index b8fd55ef..bdb0dd6b 100644
--- a/frontend/src/components/payment/StripePaymentInline.vue
+++ b/frontend/src/components/payment/StripePaymentInline.vue
@@ -67,10 +67,10 @@
import { ref, onMounted, nextTick } from 'vue'
import { useI18n } from 'vue-i18n'
import { useRouter } from 'vue-router'
-import { extractApiErrorMessage } from '@/utils/apiError'
+import { extractI18nErrorMessage } from '@/utils/apiError'
import { paymentAPI } from '@/api/payment'
import { useAppStore } from '@/stores'
-import { STRIPE_POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig'
+import { getPaymentPopupFeatures } from '@/components/payment/providerConfig'
import type { Stripe, StripeElements } from '@stripe/stripe-js'
import Icon from '@/components/icons/Icon.vue'
@@ -132,7 +132,7 @@ onMounted(async () => {
selectedType.value = event.value.type
})
} catch (err: unknown) {
- initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed'))
+ initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed'))
} finally {
loading.value = false
}
@@ -151,7 +151,7 @@ async function handlePay() {
amount: String(props.payAmount),
},
}).href
- const popup = window.open(popupUrl, 'paymentPopup', STRIPE_POPUP_WINDOW_FEATURES)
+ const popup = window.open(popupUrl, 'paymentPopup', getPaymentPopupFeatures())
const onReady = (event: MessageEvent) => {
if (event.source !== popup || event.data?.type !== 'STRIPE_POPUP_READY') return
@@ -186,7 +186,7 @@ async function handlePay() {
emit('success')
}
} catch (err: unknown) {
- error.value = extractApiErrorMessage(err, t('payment.result.failed'))
+ error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed'))
} finally {
submitting.value = false
}
@@ -199,7 +199,7 @@ async function handleCancel() {
await paymentAPI.cancelOrder(props.orderId)
emit('back')
} catch (err: unknown) {
- appStore.showError(extractApiErrorMessage(err, t('common.error')))
+ appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error')))
} finally {
cancelling.value = false
}
diff --git a/frontend/src/components/payment/__tests__/PaymentProviderDialog.spec.ts b/frontend/src/components/payment/__tests__/PaymentProviderDialog.spec.ts
new file mode 100644
index 00000000..637d805f
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/PaymentProviderDialog.spec.ts
@@ -0,0 +1,78 @@
+import { describe, expect, it, vi } from 'vitest'
+import { mount } from '@vue/test-utils'
+import { nextTick } from 'vue'
+import PaymentProviderDialog from '@/components/payment/PaymentProviderDialog.vue'
+
+const messages: Record = {
+ 'admin.settings.payment.providerConfig': 'Credentials',
+ 'admin.settings.payment.paymentGuideTrigger': 'View payment guide',
+ 'admin.settings.payment.alipayGuideSummary': 'Desktop prefers QR precreate and falls back to cashier; mobile prefers WAP checkout.',
+ 'admin.settings.payment.wxpayGuideSummary': 'Desktop prefers Native QR; mobile routes to JSAPI or H5 based on browser context.',
+}
+
+vi.mock('vue-i18n', () => ({
+ useI18n: () => ({
+ t: (key: string) => messages[key] ?? key,
+ }),
+}))
+
+function mountDialog() {
+ return mount(PaymentProviderDialog, {
+ props: {
+ show: true,
+ saving: false,
+ editing: null,
+ allKeyOptions: [
+ { value: 'alipay', label: 'Alipay' },
+ { value: 'wxpay', label: 'WeChat Pay' },
+ { value: 'stripe', label: 'Stripe' },
+ ],
+ enabledKeyOptions: [
+ { value: 'alipay', label: 'Alipay' },
+ { value: 'wxpay', label: 'WeChat Pay' },
+ ],
+ allPaymentTypes: [
+ { value: 'alipay', label: 'Alipay' },
+ { value: 'wxpay', label: 'WeChat Pay' },
+ ],
+ redirectLabel: 'Redirect',
+ },
+ global: {
+ stubs: {
+ BaseDialog: {
+ template: '
',
+ },
+ Select: {
+ props: ['modelValue', 'options', 'disabled'],
+ template: '
',
+ },
+ ToggleSwitch: {
+ template: '
',
+ },
+ },
+ },
+ })
+}
+
+describe('PaymentProviderDialog payment guide', () => {
+ it('shows no payment guide for providers without a flow guide', () => {
+ const wrapper = mountDialog()
+
+ expect(wrapper.text()).not.toContain(messages['admin.settings.payment.alipayGuideSummary'])
+ expect(wrapper.text()).not.toContain(messages['admin.settings.payment.wxpayGuideSummary'])
+ expect(wrapper.find('button[title="View payment guide"]').exists()).toBe(false)
+ })
+
+ it.each([
+ ['alipay', 'admin.settings.payment.alipayGuideSummary'],
+ ['wxpay', 'admin.settings.payment.wxpayGuideSummary'],
+ ])('shows the payment guide summary for %s', async (providerKey, summaryKey) => {
+ const wrapper = mountDialog()
+
+ ;(wrapper.vm as unknown as { reset: (key: string) => void }).reset(providerKey)
+ await nextTick()
+
+ expect(wrapper.text()).toContain(messages[summaryKey])
+ expect(wrapper.find('button[title="View payment guide"]').exists()).toBe(true)
+ })
+})
diff --git a/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
new file mode 100644
index 00000000..ea2b6377
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
@@ -0,0 +1,131 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+import { flushPromises, mount } from '@vue/test-utils'
+
+const pollOrderStatus = vi.hoisted(() => vi.fn())
+const cancelOrder = vi.hoisted(() => vi.fn())
+const showError = vi.hoisted(() => vi.fn())
+const toCanvas = vi.hoisted(() => vi.fn())
+
+vi.mock('vue-i18n', async () => {
+ const actual = await vi.importActual('vue-i18n')
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+ }
+})
+
+vi.mock('@/stores/payment', () => ({
+ usePaymentStore: () => ({
+ pollOrderStatus,
+ }),
+}))
+
+vi.mock('@/stores', () => ({
+ useAppStore: () => ({
+ showError,
+ }),
+}))
+
+vi.mock('@/api/payment', () => ({
+ paymentAPI: {
+ cancelOrder,
+ },
+}))
+
+vi.mock('qrcode', () => ({
+ default: {
+ toCanvas,
+ },
+}))
+
+import PaymentStatusPanel from '../PaymentStatusPanel.vue'
+
+const orderFactory = (status: string) => ({
+ id: 42,
+ user_id: 9,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ payment_type: 'alipay',
+ out_trade_no: 'sub2_20260420abcd1234',
+ status,
+ order_type: 'balance',
+ created_at: '2026-04-20T12:00:00Z',
+ expires_at: '2099-01-01T12:30:00Z',
+ refund_amount: 0,
+})
+
+describe('PaymentStatusPanel', () => {
+ beforeEach(() => {
+ vi.useFakeTimers()
+ pollOrderStatus.mockReset()
+ cancelOrder.mockReset()
+ showError.mockReset()
+ toCanvas.mockReset().mockResolvedValue(undefined)
+ })
+
+ afterEach(() => {
+ vi.useRealTimers()
+ })
+
+ it('treats RECHARGING as a successful terminal state', async () => {
+ pollOrderStatus.mockResolvedValue(orderFactory('RECHARGING'))
+
+ const wrapper = mount(PaymentStatusPanel, {
+ props: {
+ orderId: 42,
+ qrCode: 'https://pay.example.com/qr/42',
+ expiresAt: '2099-01-01T12:30:00Z',
+ paymentType: 'alipay',
+ orderType: 'balance',
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ },
+ },
+ })
+
+ await flushPromises()
+ await vi.advanceTimersByTimeAsync(3000)
+ await flushPromises()
+
+ expect(pollOrderStatus).toHaveBeenCalledWith(42)
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.emitted('success')).toHaveLength(1)
+ })
+
+ it('shows reopen button in QR mode when payUrl is also available', async () => {
+ const openSpy = vi.spyOn(window, 'open').mockReturnValue({ closed: false } as Window)
+
+ const wrapper = mount(PaymentStatusPanel, {
+ props: {
+ orderId: 42,
+ qrCode: 'https://pay.example.com/qr/42',
+ payUrl: 'https://pay.example.com/session/42',
+ expiresAt: '2099-01-01T12:30:00Z',
+ paymentType: 'alipay',
+ orderType: 'balance',
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ },
+ },
+ })
+
+ await flushPromises()
+ expect(wrapper.text()).toContain('payment.qr.openPayWindow')
+
+ await wrapper.get('button.btn.btn-secondary.text-sm').trigger('click')
+ expect(openSpy).toHaveBeenCalledWith(
+ 'https://pay.example.com/session/42',
+ 'paymentPopup',
+ expect.any(String),
+ )
+
+ openSpy.mockRestore()
+ })
+})
diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
new file mode 100644
index 00000000..aebec8e5
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
@@ -0,0 +1,320 @@
+import { describe, expect, it } from 'vitest'
+import type { CreateOrderResult, MethodLimit } from '@/types/payment'
+import {
+ buildCreateOrderPayload,
+ decidePaymentLaunch,
+ getVisibleMethods,
+ readPaymentRecoverySnapshot,
+ type PaymentRecoverySnapshot,
+} from '@/components/payment/paymentFlow'
+
+function methodLimit(overrides: Partial = {}): MethodLimit {
+ return {
+ daily_limit: 0,
+ daily_used: 0,
+ daily_remaining: 0,
+ single_min: 0,
+ single_max: 0,
+ fee_rate: 0,
+ available: true,
+ ...overrides,
+ }
+}
+
+function createOrderResult(overrides: Partial = {}): CreateOrderResult {
+ return {
+ order_id: 101,
+ amount: 88,
+ pay_amount: 88,
+ fee_rate: 0,
+ expires_at: '2099-01-01T00:10:00.000Z',
+ ...overrides,
+ }
+}
+
+describe('getVisibleMethods', () => {
+ it('normalizes provider aliases and keeps stripe as a top-level method', () => {
+ const visible = getVisibleMethods({
+ alipay_direct: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ stripe: methodLimit({ fee_rate: 3 }),
+ })
+
+ expect(visible).toEqual({
+ alipay: methodLimit({ single_min: 5 }),
+ wxpay: methodLimit({ single_max: 100 }),
+ stripe: methodLimit({ fee_rate: 3 }),
+ })
+ })
+
+ it('prefers canonical visible methods over aliases when both exist', () => {
+ const visible = getVisibleMethods({
+ alipay: methodLimit({ single_min: 2 }),
+ alipay_direct: methodLimit({ single_min: 9 }),
+ wxpay_direct: methodLimit({ fee_rate: 1.2 }),
+ })
+
+ expect(visible.alipay.single_min).toBe(2)
+ expect(visible.wxpay.fee_rate).toBe(1.2)
+ })
+})
+
+describe('decidePaymentLaunch', () => {
+ it('uses Stripe popup waiting flow for desktop Alipay client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ resume_token: 'resume-1',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('stripe_popup')
+ expect(decision.paymentState.paymentType).toBe('alipay')
+ expect(decision.stripeMethod).toBe('alipay')
+ expect(decision.recovery.resumeToken).toBe('resume-1')
+ expect(decision.recovery.outTradeNo).toBe('')
+ })
+
+ it('routes Stripe button click to the full Payment Element without a preselected sub-method', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ }), {
+ visibleMethod: 'stripe',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('stripe_route')
+ expect(decision.stripeMethod).toBeUndefined()
+ })
+
+ it('uses Stripe route flow for mobile WeChat client secret', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ client_secret: 'cs_test',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'subscription',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('stripe_route')
+ expect(decision.stripeMethod).toBe('wechat_pay')
+ expect(decision.paymentState.orderType).toBe('subscription')
+ })
+
+ it('keeps hosted redirect metadata for recovery flows', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/session/abc',
+ payment_mode: 'popup',
+ resume_token: 'resume-2',
+ out_trade_no: 'sub2_abc',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('redirect_waiting')
+ expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc')
+ expect(decision.recovery.paymentMode).toBe('popup')
+ expect(decision.recovery.outTradeNo).toBe('sub2_abc')
+ expect(decision.recovery.resumeToken).toBe('resume-2')
+ })
+
+ it('prefers redirect on mobile when both pay_url and qr_code are present', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/mobile/session',
+ qr_code: 'https://pay.example.com/qr/session',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('redirect_waiting')
+ expect(decision.paymentState.payUrl).toBe('https://pay.example.com/mobile/session')
+ expect(decision.paymentState.qrCode).toBe('https://pay.example.com/qr/session')
+ })
+
+ it('keeps QR flow on desktop when both pay_url and qr_code are present', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/desktop/session',
+ qr_code: 'https://pay.example.com/qr/session',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: false,
+ })
+
+ expect(decision.kind).toBe('qr_waiting')
+ expect(decision.paymentState.qrCode).toBe('https://pay.example.com/qr/session')
+ })
+
+ it('returns wechat oauth launch when backend requires in-app authorization', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ result_type: 'oauth_required',
+ payment_type: 'wxpay',
+ oauth: {
+ authorize_url: '/api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay',
+ appid: 'wx123',
+ scope: 'snsapi_base',
+ redirect_url: '/auth/wechat/payment/callback',
+ },
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('wechat_oauth')
+ expect(decision.oauth?.authorize_url).toContain('/api/v1/auth/oauth/wechat/payment/start')
+ expect(decision.paymentState.paymentType).toBe('wxpay')
+ })
+
+ it('returns wechat jsapi launch when backend has a jsapi payload ready', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ result_type: 'jsapi_ready',
+ payment_type: 'wxpay',
+ jsapi: {
+ appId: 'wx123',
+ timeStamp: '1712345678',
+ nonceStr: 'nonce-123',
+ package: 'prepay_id=wx123',
+ signType: 'RSA',
+ paySign: 'signed-payload',
+ },
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'subscription',
+ isMobile: true,
+ })
+
+ expect(decision.kind).toBe('wechat_jsapi')
+ expect(decision.jsapi?.appId).toBe('wx123')
+ expect(decision.paymentState.orderType).toBe('subscription')
+ })
+})
+
+describe('buildCreateOrderPayload', () => {
+ it('normalizes visible method aliases and attaches a canonical result URL', () => {
+ expect(buildCreateOrderPayload({
+ amount: 88,
+ paymentType: 'alipay_direct',
+ orderType: 'balance',
+ origin: 'https://app.example.com/',
+ isMobile: true,
+ isWechatBrowser: false,
+ })).toEqual({
+ amount: 88,
+ payment_type: 'alipay',
+ order_type: 'balance',
+ return_url: 'https://app.example.com/payment/result',
+ is_mobile: true,
+ payment_source: 'hosted_redirect',
+ })
+ })
+
+ it('uses WeChat in-app resume source for visible WeChat payments in the WeChat browser', () => {
+ expect(buildCreateOrderPayload({
+ amount: 128,
+ paymentType: 'wxpay',
+ orderType: 'subscription',
+ planId: 7,
+ origin: 'https://app.example.com',
+ isMobile: false,
+ isWechatBrowser: true,
+ })).toEqual({
+ amount: 128,
+ payment_type: 'wxpay',
+ order_type: 'subscription',
+ plan_id: 7,
+ return_url: 'https://app.example.com/payment/result',
+ is_mobile: false,
+ payment_source: 'wechat_in_app_resume',
+ })
+ })
+})
+
+describe('readPaymentRecoverySnapshot', () => {
+ it('restores an unexpired snapshot when the resume token matches', () => {
+ const snapshot: PaymentRecoverySnapshot = {
+ orderId: 33,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/33',
+ outTradeNo: 'sub2_33',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-33',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }
+
+ const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'resume-33',
+ })
+
+ expect(restored?.orderId).toBe(33)
+ })
+
+ it('drops expired or mismatched recovery snapshots', () => {
+ const expiredSnapshot: PaymentRecoverySnapshot = {
+ orderId: 55,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2024-01-01T00:10:00.000Z',
+ paymentType: 'wxpay',
+ payUrl: 'https://pay.example.com/session/55',
+ outTradeNo: 'sub2_55',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-55',
+ createdAt: Date.UTC(2024, 0, 1, 0, 0, 0),
+ }
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), {
+ now: Date.UTC(2024, 0, 1, 0, 20, 0),
+ resumeToken: 'resume-55',
+ })).toBeNull()
+
+ expect(readPaymentRecoverySnapshot(JSON.stringify({
+ ...expiredSnapshot,
+ outTradeNo: 'sub2_55',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ }), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'other-token',
+ })).toBeNull()
+ })
+
+ it('keeps backward compatibility with snapshots written before outTradeNo existed', () => {
+ const restored = readPaymentRecoverySnapshot(JSON.stringify({
+ orderId: 44,
+ amount: 18,
+ qrCode: '',
+ expiresAt: '2099-01-01T00:10:00.000Z',
+ paymentType: 'alipay',
+ payUrl: 'https://pay.example.com/session/44',
+ clientSecret: '',
+ payAmount: 18,
+ orderType: 'balance',
+ paymentMode: 'popup',
+ resumeToken: 'resume-44',
+ createdAt: Date.UTC(2099, 0, 1, 0, 0, 0),
+ }), {
+ now: Date.UTC(2099, 0, 1, 0, 1, 0),
+ resumeToken: 'resume-44',
+ })
+
+ expect(restored?.orderId).toBe(44)
+ expect(restored?.outTradeNo).toBe('')
+ })
+})
diff --git a/frontend/src/components/payment/__tests__/providerConfig.spec.ts b/frontend/src/components/payment/__tests__/providerConfig.spec.ts
new file mode 100644
index 00000000..ec63726f
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/providerConfig.spec.ts
@@ -0,0 +1,20 @@
+import { describe, expect, it } from 'vitest'
+import { PROVIDER_CONFIG_FIELDS } from '@/components/payment/providerConfig'
+
+function findField(key: string) {
+ const fields = PROVIDER_CONFIG_FIELDS.wxpay || []
+ return fields.find(field => field.key === key)
+}
+
+describe('PROVIDER_CONFIG_FIELDS.wxpay', () => {
+ it('keeps admin form validation aligned with backend-required credentials', () => {
+ expect(findField('publicKeyId')?.optional).toBeFalsy()
+ expect(findField('certSerial')?.optional).toBeFalsy()
+ })
+
+ it('only keeps the simplified visible credential set in the admin form', () => {
+ expect(findField('mpAppId')).toBeUndefined()
+ expect(findField('h5AppName')).toBeUndefined()
+ expect(findField('h5AppUrl')).toBeUndefined()
+ })
+})
diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts
new file mode 100644
index 00000000..318f3882
--- /dev/null
+++ b/frontend/src/components/payment/paymentFlow.ts
@@ -0,0 +1,277 @@
+import type {
+ CreateOrderRequest,
+ CreateOrderResult,
+ MethodLimit,
+ OrderType,
+ WechatJSAPIPayload,
+ WechatOAuthInfo,
+} from '@/types/payment'
+
+export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current'
+
+const VISIBLE_METHOD_ALIASES = {
+ alipay: 'alipay',
+ alipay_direct: 'alipay',
+ wxpay: 'wxpay',
+ wxpay_direct: 'wxpay',
+ stripe: 'stripe',
+} as const
+
+export type VisiblePaymentMethod = 'alipay' | 'wxpay' | 'stripe'
+export type StripeVisibleMethod = 'alipay' | 'wechat_pay'
+export type PaymentLaunchKind =
+ | 'qr_waiting'
+ | 'redirect_waiting'
+ | 'stripe_popup'
+ | 'stripe_route'
+ | 'wechat_oauth'
+ | 'wechat_jsapi'
+ | 'unhandled'
+
+export interface PaymentRecoverySnapshot {
+ orderId: number
+ amount: number
+ qrCode: string
+ expiresAt: string
+ paymentType: string
+ payUrl: string
+ outTradeNo: string
+ clientSecret: string
+ payAmount: number
+ orderType: OrderType | ''
+ paymentMode: string
+ resumeToken: string
+ createdAt: number
+}
+
+export interface PaymentLaunchContext {
+ visibleMethod: string
+ orderType: OrderType
+ isMobile: boolean
+ isWechatBrowser?: boolean
+ now?: number
+ stripePopupUrl?: string
+ stripeRouteUrl?: string
+}
+
+export interface PaymentLaunchDecision {
+ kind: PaymentLaunchKind
+ paymentState: PaymentRecoverySnapshot
+ recovery: PaymentRecoverySnapshot
+ stripeMethod?: StripeVisibleMethod
+ oauth?: WechatOAuthInfo
+ jsapi?: WechatJSAPIPayload
+}
+
+export interface BuildCreateOrderPayloadInput {
+ amount: number
+ paymentType: string
+ orderType: OrderType
+ planId?: number
+ origin?: string
+ isMobile: boolean
+ isWechatBrowser: boolean
+}
+
+type CreateOrderFlowResult = CreateOrderResult & {
+ resume_token?: string
+}
+
+type StorageWriter = Pick
+
+export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' {
+ const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES]
+ return normalized ?? ''
+}
+
+export function getVisibleMethods(methods: Record): Record {
+ const visible: Record = {}
+
+ Object.entries(methods).forEach(([type, limit]) => {
+ const normalized = normalizeVisibleMethod(type)
+ if (!normalized) return
+
+ const isCanonical = type === normalized
+ const existing = visible[normalized]
+ if (!existing || isCanonical) {
+ visible[normalized] = { ...limit }
+ }
+ })
+
+ return visible
+}
+
+export function buildCreateOrderPayload(input: BuildCreateOrderPayloadInput): CreateOrderRequest {
+ const visibleMethod = normalizeVisibleMethod(input.paymentType) || input.paymentType.trim()
+ const normalizedOrigin = (input.origin || '').trim().replace(/\/+$/, '')
+ const payload: CreateOrderRequest = {
+ amount: input.amount,
+ payment_type: visibleMethod,
+ order_type: input.orderType,
+ is_mobile: input.isMobile,
+ payment_source: visibleMethod === 'wxpay' && input.isWechatBrowser
+ ? 'wechat_in_app_resume'
+ : 'hosted_redirect',
+ }
+
+ if (input.planId) {
+ payload.plan_id = input.planId
+ }
+ if (normalizedOrigin) {
+ payload.return_url = `${normalizedOrigin}/payment/result`
+ }
+
+ return payload
+}
+
+export function decidePaymentLaunch(
+ result: CreateOrderFlowResult,
+ context: PaymentLaunchContext,
+): PaymentLaunchDecision {
+ const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod
+ const baseState = createPaymentRecoverySnapshot({
+ orderId: result.order_id,
+ amount: result.amount,
+ qrCode: result.qr_code || '',
+ expiresAt: result.expires_at || '',
+ paymentType: visibleMethod,
+ payUrl: result.pay_url || '',
+ outTradeNo: result.out_trade_no || '',
+ clientSecret: result.client_secret || '',
+ payAmount: result.pay_amount,
+ orderType: context.orderType,
+ paymentMode: (result.payment_mode || '').trim(),
+ resumeToken: result.resume_token || '',
+ }, context.now)
+
+ if (baseState.clientSecret) {
+ // visibleMethod === 'stripe' means the user clicked the dedicated Stripe button
+ // and should land on the full Payment Element to choose a sub-method themselves.
+ const isStripeButton = visibleMethod === 'stripe'
+ const stripeMethod: StripeVisibleMethod | undefined = isStripeButton
+ ? undefined
+ : visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay'
+ const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile
+ ? 'stripe_popup'
+ : 'stripe_route'
+ const payUrl = kind === 'stripe_popup'
+ ? context.stripePopupUrl || context.stripeRouteUrl || ''
+ : context.stripeRouteUrl || context.stripePopupUrl || ''
+ const paymentState = { ...baseState, payUrl }
+ return { kind, paymentState, recovery: paymentState, stripeMethod }
+ }
+
+ if (result.result_type === 'oauth_required' && result.oauth?.authorize_url) {
+ return { kind: 'wechat_oauth', paymentState: baseState, recovery: baseState, oauth: result.oauth }
+ }
+
+ const jsapiPayload = result.jsapi ?? result.jsapi_payload
+ if (result.result_type === 'jsapi_ready' && jsapiPayload) {
+ return { kind: 'wechat_jsapi', paymentState: baseState, recovery: baseState, jsapi: jsapiPayload }
+ }
+
+ const normalizedPaymentMode = baseState.paymentMode.trim().toLowerCase()
+ const prefersRedirect = normalizedPaymentMode === 'redirect'
+ || normalizedPaymentMode === 'popup'
+ || (context.isMobile && !!baseState.payUrl)
+ const prefersQr = normalizedPaymentMode === 'qrcode'
+ || normalizedPaymentMode === 'native'
+ || (!prefersRedirect && !!baseState.qrCode)
+
+ if (visibleMethod === 'wxpay' && context.isWechatBrowser && baseState.payUrl && !baseState.qrCode) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (prefersRedirect && baseState.payUrl) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (prefersQr && baseState.qrCode) {
+ return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ if (baseState.payUrl) {
+ return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState }
+ }
+
+ return { kind: 'unhandled', paymentState: baseState, recovery: baseState }
+}
+
+export function createPaymentRecoverySnapshot(
+ state: Omit,
+ now = Date.now(),
+): PaymentRecoverySnapshot {
+ return {
+ ...state,
+ createdAt: now,
+ }
+}
+
+export function writePaymentRecoverySnapshot(
+ storage: StorageWriter,
+ snapshot: PaymentRecoverySnapshot,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.setItem(key, JSON.stringify(snapshot))
+}
+
+export function clearPaymentRecoverySnapshot(
+ storage: Pick,
+ key = PAYMENT_RECOVERY_STORAGE_KEY,
+): void {
+ storage.removeItem(key)
+}
+
+export function readPaymentRecoverySnapshot(
+ raw: string | null | undefined,
+ options: { now?: number; resumeToken?: string } = {},
+): PaymentRecoverySnapshot | null {
+ if (!raw) return null
+
+ try {
+ const parsed = JSON.parse(raw) as Partial
+ if (
+ typeof parsed.orderId !== 'number'
+ || typeof parsed.amount !== 'number'
+ || typeof parsed.qrCode !== 'string'
+ || typeof parsed.expiresAt !== 'string'
+ || typeof parsed.paymentType !== 'string'
+ || typeof parsed.payUrl !== 'string'
+ || (parsed.outTradeNo != null && typeof parsed.outTradeNo !== 'string')
+ || typeof parsed.clientSecret !== 'string'
+ || typeof parsed.payAmount !== 'number'
+ || typeof parsed.paymentMode !== 'string'
+ || typeof parsed.resumeToken !== 'string'
+ || typeof parsed.createdAt !== 'number'
+ ) {
+ return null
+ }
+
+ const now = options.now ?? Date.now()
+ const expiresAt = Date.parse(parsed.expiresAt)
+ if (Number.isFinite(expiresAt) && expiresAt <= now) {
+ return null
+ }
+ if (options.resumeToken && parsed.resumeToken !== options.resumeToken) {
+ return null
+ }
+
+ return {
+ orderId: parsed.orderId,
+ amount: parsed.amount,
+ qrCode: parsed.qrCode,
+ expiresAt: parsed.expiresAt,
+ paymentType: parsed.paymentType,
+ payUrl: parsed.payUrl,
+ outTradeNo: parsed.outTradeNo || '',
+ clientSecret: parsed.clientSecret,
+ payAmount: parsed.payAmount,
+ orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance',
+ paymentMode: parsed.paymentMode,
+ resumeToken: parsed.resumeToken,
+ createdAt: parsed.createdAt,
+ }
+ } catch {
+ return null
+ }
+}
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index a83787fd..f4f5acdc 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -43,11 +43,24 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct',
export const PAYMENT_MODE_QRCODE = 'qrcode'
export const PAYMENT_MODE_POPUP = 'popup'
-/** Window features for payment popup windows */
-export const POPUP_WINDOW_FEATURES = 'width=1000,height=750,left=100,top=80,scrollbars=yes,resizable=yes'
+/** Preferred popup size for payment gateways. Alipay's standard checkout
+ * (QR + account login panel) needs ~1200×900 to render without any scrolling. */
+const PAYMENT_POPUP_PREFERRED_WIDTH = 1250
+const PAYMENT_POPUP_PREFERRED_HEIGHT = 900
-/** Wider popup for Stripe redirect methods (Alipay checkout page needs ~1200px) */
-export const STRIPE_POPUP_WINDOW_FEATURES = 'width=1250,height=780,left=80,top=60,scrollbars=yes,resizable=yes'
+/** Build a window.open features string sized to fit within the current screen
+ * while preferring the above dimensions. Centers the popup on the available
+ * work area so nothing is clipped on smaller laptop displays. */
+export function getPaymentPopupFeatures(): string {
+ const screen = typeof window !== 'undefined' ? window.screen : null
+ const availW = screen?.availWidth ?? PAYMENT_POPUP_PREFERRED_WIDTH
+ const availH = screen?.availHeight ?? PAYMENT_POPUP_PREFERRED_HEIGHT
+ const width = Math.min(PAYMENT_POPUP_PREFERRED_WIDTH, availW - 40)
+ const height = Math.min(PAYMENT_POPUP_PREFERRED_HEIGHT, availH - 40)
+ const left = Math.max(0, Math.floor((availW - width) / 2))
+ const top = Math.max(0, Math.floor((availH - height) / 2))
+ return `width=${width},height=${height},left=${left},top=${top},scrollbars=yes,resizable=yes`
+}
/** Webhook paths for each provider (relative to origin). */
export const WEBHOOK_PATHS: Record = {
@@ -86,9 +99,9 @@ export const PROVIDER_CONFIG_FIELDS: Record = {
{ key: 'mchId', label: '', sensitive: false },
{ key: 'privateKey', label: '', sensitive: true },
{ key: 'apiV3Key', label: '', sensitive: true },
+ { key: 'certSerial', label: '', sensitive: false },
{ key: 'publicKey', label: '', sensitive: true },
- { key: 'publicKeyId', label: '', sensitive: false, optional: true },
- { key: 'certSerial', label: '', sensitive: false, optional: true },
+ { key: 'publicKeyId', label: '', sensitive: false },
],
stripe: [
{ key: 'secretKey', label: '', sensitive: true },
diff --git a/frontend/src/components/user/MonitorDetailDialog.vue b/frontend/src/components/user/MonitorDetailDialog.vue
new file mode 100644
index 00000000..564f461b
--- /dev/null
+++ b/frontend/src/components/user/MonitorDetailDialog.vue
@@ -0,0 +1,114 @@
+
+
+
+ {{ t('common.loading') }}
+
+
+ {{ t('channelStatus.detailLoadError') }}
+
+
+
+
+
+ {{ t('channelStatus.detailColumns.model') }}
+ {{ t('channelStatus.detailColumns.latestStatus') }}
+ {{ t('channelStatus.detailColumns.latestLatency') }}
+ {{ t('channelStatus.detailColumns.availability7d') }}
+ {{ t('channelStatus.detailColumns.availability15d') }}
+ {{ t('channelStatus.detailColumns.availability30d') }}
+ {{ t('channelStatus.detailColumns.avgLatency7d') }}
+
+
+
+
+ {{ m.model }}
+
+
+ {{ statusLabel(m.latest_status) }}
+
+
+ {{ formatLatency(m.latest_latency_ms) }}
+ {{ formatPercent(m.availability_7d) }}
+ {{ formatPercent(m.availability_15d) }}
+ {{ formatPercent(m.availability_30d) }}
+ {{ formatLatency(m.avg_latency_7d_ms) }}
+
+
+
+
+
+
+
+
+ {{ t('channelStatus.closeDetail') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue b/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue
new file mode 100644
index 00000000..34420c9d
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorAvailabilityRow.vue
@@ -0,0 +1,49 @@
+
+
+
+ {{ windowLabel }}
+
+
+
+ {{ displayValue }}
+
+ %
+
+
+
+ {{ samplesLabel }}
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorCard.vue b/frontend/src/components/user/monitor/MonitorCard.vue
new file mode 100644
index 00000000..33742c6d
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorCard.vue
@@ -0,0 +1,128 @@
+
+
+
+
+
+
+
+
+
+ {{ item.name }}
+
+
+
+ {{ providerLabel(item.provider) }}
+
+
+ {{ item.primary_model }}
+
+
+ {{ item.group_name }}
+
+
+
+
+ {{ statusLabel(item.primary_status) }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorCardGrid.vue b/frontend/src/components/user/monitor/MonitorCardGrid.vue
new file mode 100644
index 00000000..c7d24c01
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorCardGrid.vue
@@ -0,0 +1,81 @@
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorHero.vue b/frontend/src/components/user/monitor/MonitorHero.vue
new file mode 100644
index 00000000..bc2b5f6f
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorHero.vue
@@ -0,0 +1,116 @@
+
+
+
+
+
+ {{ opt.label }}
+
+
+
+
+
+ {{ overallLabel }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorMetricPair.vue b/frontend/src/components/user/monitor/MonitorMetricPair.vue
new file mode 100644
index 00000000..0f3fd3dc
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorMetricPair.vue
@@ -0,0 +1,45 @@
+
+
+
+
+
+ {{ primaryLabel }}
+
+
+ {{ primaryValue }}{{ primaryUnit }}
+
+
+
+
+
+ {{ secondaryLabel }}
+
+
+ {{ secondaryValue }}{{ secondaryUnit }}
+
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/MonitorTimeline.vue b/frontend/src/components/user/monitor/MonitorTimeline.vue
new file mode 100644
index 00000000..2445bc51
--- /dev/null
+++ b/frontend/src/components/user/monitor/MonitorTimeline.vue
@@ -0,0 +1,115 @@
+
+
+
+ {{ t('monitorCommon.history60pts', { n: length }) }}
+ {{ t('monitorCommon.nextUpdateIn', { n: countdownSeconds }) }}
+
+
+
+ {{ t('monitorCommon.maintenancePaused') }}
+
+
+
+
+ {{ t('monitorCommon.past') }}
+ {{ t('monitorCommon.now') }}
+
+
+
+
+
diff --git a/frontend/src/components/user/monitor/ProviderIcon.vue b/frontend/src/components/user/monitor/ProviderIcon.vue
new file mode 100644
index 00000000..20456a2c
--- /dev/null
+++ b/frontend/src/components/user/monitor/ProviderIcon.vue
@@ -0,0 +1,71 @@
+
+
+
+
+
+ {{ fallbackText }}
+
+
+
+
diff --git a/frontend/src/components/user/profile/ProfileAccountBindingsCard.vue b/frontend/src/components/user/profile/ProfileAccountBindingsCard.vue
new file mode 100644
index 00000000..f1cf54a9
--- /dev/null
+++ b/frontend/src/components/user/profile/ProfileAccountBindingsCard.vue
@@ -0,0 +1,36 @@
+
+
+
+
+
diff --git a/frontend/src/components/user/profile/ProfileAvatarCard.vue b/frontend/src/components/user/profile/ProfileAvatarCard.vue
new file mode 100644
index 00000000..9ff26853
--- /dev/null
+++ b/frontend/src/components/user/profile/ProfileAvatarCard.vue
@@ -0,0 +1,270 @@
+
+
+
+
+ {{ t('profile.avatar.title') }}
+
+
+ {{ t('profile.avatar.description') }}
+
+
+
+
+
+
+
{{ avatarInitial }}
+
+
+
+
+
+ {{ t('profile.avatar.title') }}
+
+
+ {{ displayName }}
+
+
+ {{ t('profile.avatar.uploadHint') }}
+
+
+
+
+
+
+ {{ t('profile.avatar.uploadAction') }}
+
+
+
+ {{ t('common.save') }}
+
+
+
+ {{ t('common.delete') }}
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/profile/ProfileEditForm.vue b/frontend/src/components/user/profile/ProfileEditForm.vue
index 2750840a..e1441921 100644
--- a/frontend/src/components/user/profile/ProfileEditForm.vue
+++ b/frontend/src/components/user/profile/ProfileEditForm.vue
@@ -1,12 +1,20 @@
-
-
+
+
{{ t('profile.editProfile') }}
-