diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 22f42455..2f79a551 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -15,12 +15,12 @@ "lucide-react": "^0.539.0", "react": "^18.1.1", "react-dom": "^18.1.1", + "react-force-graph-2d": "^1.29.0", "react-force-graph-3d": "^1.29.0", "react-redux": "^9.2.0", "react-router": "^7.8.0", "recharts": "2.15.0", "tailwind-merge": "^2.5.2", - "three": "^0.182.0", "three-spritetext": "^1.8.4" }, "devDependencies": { @@ -2768,6 +2768,16 @@ "dev": true, "license": "MIT" }, + "node_modules/bezier-js": { + "version": "6.1.4", + "resolved": "https://registry.npmjs.org/bezier-js/-/bezier-js-6.1.4.tgz", + "integrity": "sha512-PA0FW9ZpcHbojUCMu28z9Vg/fNkwTj5YhusSAjHHDfHDGLxJ6YUKrAN2vk1fP2MMOxVw4Oko16FMlRGVBGqLKg==", + "license": "MIT", + "funding": { + "type": "individual", + "url": "https://github.com/Pomax/bezierjs/blob/master/FUNDING.md" + } + }, "node_modules/binary-extensions": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", @@ -2938,6 +2948,18 @@ ], "license": "CC-BY-4.0" }, + "node_modules/canvas-color-tracker": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/canvas-color-tracker/-/canvas-color-tracker-1.3.2.tgz", + "integrity": "sha512-ryQkDX26yJ3CXzb3hxUVNlg1NKE4REc5crLBq661Nxzr8TNd236SaEf2ffYLXyI5tSABSeguHLqcVq4vf9L3Zg==", + "license": "MIT", + "dependencies": { + "tinycolor2": "^1.6.0" + }, + "engines": { + "node": ">=12" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -4094,6 +4116,32 @@ "node": ">=12" } }, + "node_modules/force-graph": { + "version": "1.51.0", + "resolved": "https://registry.npmjs.org/force-graph/-/force-graph-1.51.0.tgz", + "integrity": "sha512-aTnihCmiMA0ItLJLCbrQYS9mzriopW24goFPgUnKAAmAlPogTSmFWqoBPMXzIfPb7bs04Hur5zEI4WYgLW3Sig==", + "license": "MIT", + "dependencies": { + "@tweenjs/tween.js": "18 - 25", + "accessor-fn": "1", + "bezier-js": "3 - 6", + "canvas-color-tracker": "^1.3", + "d3-array": "1 - 3", + "d3-drag": "2 - 3", + "d3-force-3d": "2 - 3", + "d3-scale": "1 - 4", + "d3-scale-chromatic": "1 - 3", + "d3-selection": "2 - 3", + "d3-zoom": "2 - 3", + "float-tooltip": "^1.7", + "index-array-by": "1", + "kapsule": "^1.16", + "lodash-es": "4" + }, + "engines": { + "node": ">=12" + } + }, "node_modules/forwarded": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", @@ -4386,6 +4434,15 @@ "node": ">=0.8.19" } }, + "node_modules/index-array-by": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/index-array-by/-/index-array-by-1.4.2.tgz", + "integrity": "sha512-SP23P27OUKzXWEC/TOyWlwLviofQkCSCKONnc62eItjp69yCZZPqDQtr3Pw5gJDnPeUMqExmKydNZaJO0FU9pw==", + "license": "MIT", + "engines": { + "node": ">=12" + } + }, "node_modules/inherits": { "version": "2.0.4", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", @@ -6299,6 +6356,23 @@ "react": "^18.3.1" } }, + "node_modules/react-force-graph-2d": { + "version": "1.29.0", + "resolved": "https://registry.npmjs.org/react-force-graph-2d/-/react-force-graph-2d-1.29.0.tgz", + "integrity": "sha512-Xv5IIk+hsZmB3F2ibja/t6j/b0/1T9dtFOQacTUoLpgzRHrO6wPu1GtQ2LfRqI/imgtaapnXUgQaE8g8enPo5w==", + "license": "MIT", + "dependencies": { + "force-graph": "^1.51", + "prop-types": "15", + "react-kapsule": "^2.5" + }, + "engines": { + "node": ">=12" + }, + "peerDependencies": { + "react": "*" + } + }, "node_modules/react-force-graph-3d": { "version": "1.29.0", "resolved": "https://registry.npmjs.org/react-force-graph-3d/-/react-force-graph-3d-1.29.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index 10fbb1ca..fe3073a0 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -18,12 +18,12 @@ "lucide-react": "^0.539.0", "react": "^18.1.1", "react-dom": "^18.1.1", + "react-force-graph-2d": "^1.29.0", "react-force-graph-3d": "^1.29.0", "react-redux": "^9.2.0", "react-router": "^7.8.0", "recharts": "2.15.0", "tailwind-merge": "^2.5.2", - "three": "^0.182.0", "three-spritetext": "^1.8.4" }, "devDependencies": { diff --git a/frontend/src/pages/KnowledgeBase/components/KnowledgeGraphView.tsx b/frontend/src/pages/KnowledgeBase/components/KnowledgeGraphView.tsx index 3539747e..992d90f0 100644 --- a/frontend/src/pages/KnowledgeBase/components/KnowledgeGraphView.tsx +++ b/frontend/src/pages/KnowledgeBase/components/KnowledgeGraphView.tsx @@ -1,9 +1,6 @@ -import { useMemo, useRef, useEffect, useCallback } from "react"; -import ForceGraph3D, { ForceGraphMethods } from "react-force-graph-3d"; -import type { KnowledgeGraphEdge, KnowledgeGraphNode } from "../knowledge-base.model"; -import { Empty } from "antd"; -import * as THREE from "three"; -import SpriteText from "three-spritetext"; +import React, {useMemo, useRef, useEffect} from "react"; +import ForceGraph2D from "react-force-graph-2d"; +import type {KnowledgeGraphEdge, KnowledgeGraphNode} from "../knowledge-base.model"; export type GraphEntitySelection = | { type: "node"; data: KnowledgeGraphNode } @@ -16,356 +13,167 @@ interface KnowledgeGraphViewProps { onSelectEntity?: (selection: GraphEntitySelection | null) => void; } -const KnowledgeGraphView: React.FC = ({ - nodes, - edges, - height = 520, - onSelectEntity, -}) => { - const graphRef = useRef(); - const lightingInitializedRef = useRef(false); - - const degreeMap = useMemo(() => { - const map = new Map(); - edges.forEach((edge) => { - map.set(String(edge.source), (map.get(String(edge.source)) || 0) + 1); - map.set(String(edge.target), (map.get(String(edge.target)) || 0) + 1); - }); - return map; - }, [edges]); - - const graphData = useMemo( - () => ({ - nodes: nodes.map((node) => ({ ...node })), - links: edges.map((edge) => { - const enrichedEdge = { - ...edge, - source: edge.source, - target: edge.target, - keywords: edge.properties?.keywords || edge.type, - } as any; - enrichedEdge.__originalEdge = edge; - return enrichedEdge; - }), - }), - [nodes, edges] - ); - - const handleLinkSelect = useCallback( - (link: any) => { - onSelectEntity?.({ type: "edge", data: normalizeLinkData(link) }); - }, - [onSelectEntity] - ); +const COLOR_PALETTE = ["#60a5fa", "#f87171", "#fbbf24", "#34d399", "#a78bfa", "#fb7185", "#22d3ee", "#818cf8", "#fb923c", "#4ade80"]; - useEffect(() => { - graphRef.current?.zoomToFit(800); - }, [graphData]); +const KnowledgeGraphView: React.FC = ({ + nodes, + edges, + height = 520, + onSelectEntity, + }) => { + const graphRef = useRef(); useEffect(() => { - if (lightingInitializedRef.current) return; - const graph = graphRef.current; - const scene = graph?.scene?.(); - if (!scene) return; - - const ambient = new THREE.AmbientLight(0xffffff, 0.35); - const key = new THREE.DirectionalLight(0xffffff, 0.8); - key.position.set(120, 160, 220); - const rim = new THREE.DirectionalLight(0x3b82f6, 0.5); - rim.position.set(-140, -120, -180); - const fill = new THREE.DirectionalLight(0xffffff, 0.45); - fill.position.set(-60, 40, 140); - - ambient.name = "kg-ambient-light"; - key.name = "kg-key-light"; - rim.name = "kg-rim-light"; - fill.name = "kg-fill-light"; - - scene.add(ambient, key, rim, fill); - lightingInitializedRef.current = true; - - return () => { - scene.remove(ambient); - scene.remove(key); - scene.remove(rim); - scene.remove(fill); - lightingInitializedRef.current = false; - }; - }, [graphData]); - - if (!nodes.length) { - return ( -
- -
- ); - } + if (graphRef.current) { + // 1. 调整力导向平衡:减小斥力让独立图块靠近,增加向心力防止飘散 + graphRef.current.d3Force("charge").strength(-250); // 斥力适中 + graphRef.current.d3Force("link").distance(120); // 边长适中 + graphRef.current.d3Force("center").strength(0.8); // 增强向心力,让孤立集群往中间靠 + } + }, [nodes]); + + const typeColorMap = useMemo(() => { + const map = new Map(); + const types = Array.from(new Set(nodes.map(n => n.properties?.entity_type || (n.labels && n.labels[0]) || 'default'))); + types.forEach((type, i) => map.set(type, COLOR_PALETTE[i % COLOR_PALETTE.length])); + return map; + }, [nodes]); + + const graphData = useMemo(() => ({ + nodes: nodes.map((node) => ({ + ...node, + color: typeColorMap.get(node.properties?.entity_type || (node.labels && node.labels[0]) || 'default'), + val: 8 // 统一基础大小,使视觉更整洁 + })), + links: edges.map((edge) => ({ + ...edge, + __originalEdge: edge, + keywords: edge.properties?.keywords || edge.type || "" + })), + }), [nodes, edges, typeColorMap]); return ( -
- + "rgba(14,165,233,0.9)"} - linkWidth={(link: any) => { - const weight = Number(link.properties?.weight ?? link.properties?.score ?? 1); - return Math.min(1.2 + weight * 0.4, 4); - }} - linkDirectionalParticles={2} - linkDirectionalParticleWidth={3.5} - linkDirectionalParticleSpeed={0.0035} - linkDirectionalParticleColor={() => "rgba(248,250,252,0.85)"} - linkCurvature={0.25} - d3VelocityDecay={0.18} - linkDistance={(link: any) => computeLinkDistance(link, degreeMap)} - nodeAutoColorBy={(node: any) => node.properties?.entity_type || "default"} - nodeOpacity={1} - nodeLabel={(node: any) => node.id} - linkLabel={(link: any) => link.keywords} - nodeThreeObject={(node: any) => { - const radius = getNodeRadius(node.id, degreeMap); - const color = node.color || "#60a5fa"; - const group = new THREE.Group(); - const sphereRadius = getSphereDisplayRadius(radius); - const baseColor = new THREE.Color(color); - - const litSphere = new THREE.Mesh(getSphereGeometry(sphereRadius), getSphereMaterial(baseColor)); - group.add(litSphere); - const innerSphere = new THREE.Mesh( - getSphereGeometry(Math.max(sphereRadius * 0.65, 0.6)), - new THREE.MeshLambertMaterial({ - color: baseColor.clone().offsetHSL(0, 0, 0.15), - emissive: baseColor.clone().multiplyScalar(0.2), - transparent: true, - opacity: 0.75, - }) - ); - innerSphere.renderOrder = 2; - group.add(innerSphere); - - const highlightOrb = createHighlightOrb(sphereRadius, baseColor); - if (highlightOrb) { - group.add(highlightOrb); - } - - const label = new SpriteText(node.id || "", 1, "#f8fafc"); - label.center.set(0.5, 0.5); - label.material.depthWrite = false; - label.material.depthTest = false; - label.renderOrder = 50; - const maxDiameter = radius * 0.95; - const fontRatio = Math.max(Math.min((radius / 18) * 5, 5), 1.5) * 1.15; - label.textHeight = Math.min(maxDiameter, radius * 0.7) / fontRatio; - label.position.set(0, 0, sphereRadius + label.textHeight * 0.95); - group.add(label); - - return group; - }} - linkThreeObjectExtend={true} - linkThreeObject={(link: any) => { - const text = String(link.keywords || "").trim(); - if (!text) { - return new THREE.Object3D(); + // --- 边视觉 --- + linkColor={() => "rgba(255, 255, 255, 0.2)"} + linkWidth={1.2} + linkDirectionalArrowLength={3} + linkDirectionalArrowRelPos={1} + linkCurvature={0.1} + + // --- 节点绘制 --- + nodeCanvasObject={(node: any, ctx, globalScale) => { + const {x, y, val: radius, color, id} = node; + if (!Number.isFinite(x) || !Number.isFinite(y)) return; + + ctx.save(); + ctx.beginPath(); + ctx.arc(x, y, radius, 0, 2 * Math.PI); + ctx.fillStyle = color; + ctx.shadowBlur = 10 / globalScale; + ctx.shadowColor = color; + ctx.fill(); + + // 节点名称 + if (globalScale > 0.4) { + const fontSize = 12 / globalScale; + ctx.font = `${fontSize}px Sans-Serif`; + ctx.textAlign = 'center'; + ctx.textBaseline = 'top'; + ctx.fillStyle = '#ffffff'; + ctx.shadowBlur = 0; + ctx.fillText(id, x, y + radius + 2); } - const label = new SpriteText(text, 1, "#e2e8f0"); - label.center.set(0.5, 0.5); - label.material.depthWrite = false; - label.material.depthTest = false; - label.renderOrder = 15; - label.textHeight = 4; - (label as any).__graphObjType = "link"; - (label as any).__data = link; - label.userData.normalizedEdge = normalizeLinkData(link); - return label; + ctx.restore(); }} - linkPositionUpdate={(sprite, { start, end }) => { - const middlePos = { - x: start.x + (end.x - start.x) / 2, - y: start.y + (end.y - start.y) / 2, - z: start.z + (end.z - start.z) / 2, - }; - Object.assign(sprite.position, middlePos); - const dx = end.x - start.x; - const dy = end.y - start.y; - const angle = Math.atan2(dy, dx); - const material = (sprite as SpriteText).material as THREE.SpriteMaterial | undefined; - if (material) { - material.rotation = angle; - } - }} - onNodeClick={(node: any) => onSelectEntity?.({ type: "node", data: node })} - onLinkClick={handleLinkSelect} - onBackgroundClick={() => onSelectEntity?.(null)} - /> -
- ); -}; - -export default KnowledgeGraphView; - -const circleTextureCache = new Map(); -const sphereMaterialCache = new Map(); -const sphereGeometryCache = new Map(); -function getCircleTexture(color: string, opacity = 1, soft = false) { - const key = `${color}-${opacity}-${soft}`; - if (circleTextureCache.has(key)) { - return circleTextureCache.get(key)!; - } - const size = 512; - const canvas = document.createElement("canvas"); - canvas.width = size; - canvas.height = size; - const ctx = canvas.getContext("2d"); - if (!ctx) return null; - - ctx.clearRect(0, 0, size, size); - if (soft) { - const gradient = ctx.createRadialGradient(size / 2, size / 2, size / 3, size / 2, size / 2, size / 2); - gradient.addColorStop(0, hexToRgba(color, opacity * 0.15)); - gradient.addColorStop(1, hexToRgba(color, 0)); - ctx.fillStyle = gradient; - ctx.fillRect(0, 0, size, size); - } else { - ctx.fillStyle = hexToRgba(color, opacity); - ctx.beginPath(); - ctx.arc(size / 2, size / 2, size / 2, 0, Math.PI * 2); - ctx.closePath(); - ctx.fill(); - } - - const texture = new THREE.CanvasTexture(canvas); - texture.needsUpdate = true; - circleTextureCache.set(key, texture); - return texture; -} + linkPointerAreaPaint={(link: any, color, ctx, globalScale) => { + const label = link.keywords; + if (!label || globalScale < 1.1) return; -function hexToRgba(hex: string, alpha: number) { - const parsedHex = hex.replace("#", ""); - const bigint = Number.parseInt(parsedHex.length === 3 ? parsedHex.repeat(2) : parsedHex, 16); - const r = (bigint >> 16) & 255; - const g = (bigint >> 8) & 255; - const b = bigint & 255; - return `rgba(${r}, ${g}, ${b}, ${alpha})`; -} + const start = link.source; + const end = link.target; + if (typeof start !== 'object' || typeof end !== 'object') return; -function createNodeLabelTexture(text: string, radius: number) { - return null; -} + const fontSize = 9 / globalScale; + const textPos = {x: start.x + (end.x - start.x) * 0.5, y: start.y + (end.y - start.y) * 0.5}; + const angle = Math.atan2(end.y - start.y, end.x - start.x); + const bRotate = angle > Math.PI / 2 || angle < -Math.PI / 2; -function createEdgeLabelTexture(text: string) { - return createTextTexture(text, { - fontSize: 10, - paddingX: 4, - paddingY: 2, - backgroundFill: null, - textFill: "rgba(241,245,249,0.9)", - maxWidth: 60, - }); -} + ctx.save(); + ctx.translate(textPos.x, textPos.y); + ctx.rotate(bRotate ? angle + Math.PI : angle); -function getNodeRadius(nodeId: string, degreeMap: Map) { - const degree = degreeMap.get(nodeId) || 1; - return Math.min(12 + degree * 4, 64); -} + ctx.font = `${fontSize}px Sans-Serif`; + const textWidth = ctx.measureText(label).width; -function getSphereDisplayRadius(nodeRadius: number) { - return Math.max(nodeRadius * 0.16, 2.2); -} + // 绘制一个与文字大小相同的透明矩形,颜色必须使用参数中的 'color' + // 这是 react-force-graph 识别点击对象的关键(Color-picking 技术) + ctx.fillStyle = color; + ctx.fillRect(-textWidth / 2 - 2, -fontSize / 2 - 2, textWidth + 4, fontSize + 4); + ctx.restore(); + }} -function getSphereGeometry(radius: number) { - const key = Number(radius.toFixed(2)); - if (!sphereGeometryCache.has(key)) { - sphereGeometryCache.set(key, new THREE.SphereGeometry(radius, 48, 48)); - } - return sphereGeometryCache.get(key)!; -} + // --- 边文字绘制:优化大小、位置和翻转逻辑 --- + linkCanvasObjectMode={() => 'after'} + linkCanvasObject={(link: any, ctx, globalScale) => { + const MAX_DISPLAY_SCALE = 1.1; + if (globalScale < MAX_DISPLAY_SCALE) return; -function getSphereMaterial(color: THREE.Color) { - const key = color.getHexString(); - if (!sphereMaterialCache.has(key)) { - const specular = new THREE.Color(1, 1, 1).lerp(color.clone(), 0.35); - sphereMaterialCache.set( - key, - new THREE.MeshPhongMaterial({ - color: color.clone(), - emissive: color.clone().multiplyScalar(0.12), - specular, - shininess: 85, - reflectivity: 0.4, - }) - ); - } - return sphereMaterialCache.get(key)!; -} + const label = link.keywords; + const start = link.source; + const end = link.target; + if (typeof start !== 'object' || typeof end !== 'object') return; -function createHighlightOrb(sphereRadius: number, baseColor: THREE.Color) { - const orbRadius = Math.max(sphereRadius * 0.28, 0.4); - const geometry = getSphereGeometry(orbRadius); - const material = new THREE.MeshBasicMaterial({ - color: baseColor.clone().offsetHSL(0, 0, 0.35), - transparent: true, - opacity: 0.85, - }); - const orb = new THREE.Mesh(geometry, material); - orb.position.set(sphereRadius * 0.45, sphereRadius * 0.5, sphereRadius * 0.65); - orb.renderOrder = 6; - return orb; -} + // 边文字比节点文字小一点点(节点12,边11) + const fontSize = 11 / globalScale; -function normalizeLinkData(link: any): KnowledgeGraphEdge { - if (!link) { - return { - id: "", - type: "", - source: "", - target: "", - properties: {}, - }; - } + const textPos = { + x: start.x + (end.x - start.x) * 0.5, + y: start.y + (end.y - start.y) * 0.5 + }; - if ((link as any).__normalizedEdge) { - return (link as any).__normalizedEdge as KnowledgeGraphEdge; - } + let angle = Math.atan2(end.y - start.y, end.x - start.x); - const normalized: KnowledgeGraphEdge = { - id: String(link.id ?? link.__id ?? ""), - type: String(link.type ?? ""), - source: extractNodeId(link.source), - target: extractNodeId(link.target), - properties: { ...(link.properties ?? {}) }, - }; + // --- 核心修复:防止文字倒挂 --- + // 如果角度在 90度 到 270度 之间,旋转180度让文字保持正向 + const bRotate = angle > Math.PI / 2 || angle < -Math.PI / 2; - if (link.keywords && !normalized.properties.keywords) { - (normalized.properties as Record).keywords = link.keywords; - } + ctx.save(); + ctx.translate(textPos.x, textPos.y); + ctx.rotate(bRotate ? angle + Math.PI : angle); - (link as any).__normalizedEdge = normalized; - return normalized; -} + ctx.font = `${fontSize}px Sans-Serif`; + const textWidth = ctx.measureText(label).width; -function extractNodeId(nodeRef: any) { - if (nodeRef == null) return ""; - if (typeof nodeRef === "string" || typeof nodeRef === "number") { - return String(nodeRef); - } - return String(nodeRef.id ?? nodeRef.__id ?? nodeRef.name ?? ""); -} + // 绘制极小的背景遮罩,紧贴文字 + ctx.fillStyle = 'rgba(1, 3, 15, 0.7)'; + ctx.fillRect(-textWidth / 2 - 1, -fontSize / 2, textWidth + 2, fontSize); -function computeLinkDistance(link: any, degreeMap: Map) { - const sourceId = extractNodeId(link.source); - const targetId = extractNodeId(link.target); - const sourceRadius = getNodeRadius(sourceId, degreeMap); - const targetRadius = getNodeRadius(targetId, degreeMap); - const minimumGap = (sourceRadius + targetRadius) * 5; + ctx.fillStyle = '#94e2d5'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + // y轴偏移设为0,使其紧贴线条中心 + ctx.fillText(label, 0, 0); + ctx.restore(); + }} - const degreeBoost = ((degreeMap.get(sourceId) || 1) + (degreeMap.get(targetId) || 1)) / 2; - const weight = Number(link.properties?.weight ?? link.properties?.score ?? 1); - const base = 260; - const dynamicDistance = base + degreeBoost * 55 + weight * 40; + onNodeClick={(node: any) => onSelectEntity?.({type: "node", data: node})} + onLinkClick={(link: any) => { + const originalData = link.__originalEdge || link; + onSelectEntity?.({type: "edge", data: originalData}); + }} + onBackgroundClick={() => onSelectEntity?.(null)} + cooldownTicks={120} + d3VelocityDecay={0.4} // 增加阻力,使布局更快稳定 + /> + + ); +}; - return Math.min(Math.max(dynamicDistance, minimumGap) * 100, 500); -} +export default KnowledgeGraphView; diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index adf5dae6..67cdfc15 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -18,6 +18,7 @@ build_llm_model_func, initialize_rag, ) +from ...system.service.common_service import get_embedding_dimension, get_openai_client logger = get_logger(__name__) @@ -53,7 +54,7 @@ async def init_graph_rag(self, knowledge_base_id: str): embedding_model.model_name, embedding_model.base_url, embedding_model.api_key, - embedding_dim=embedding_model.embedding_dim if hasattr(embedding_model, "embedding_dim") else 1024, + embedding_dim=get_embedding_dimension(get_openai_client(embedding_model)), ) kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name) diff --git a/runtime/datamate-python/app/module/system/service/common_service.py b/runtime/datamate-python/app/module/system/service/common_service.py index 22ed6ecd..2cf5578b 100644 --- a/runtime/datamate-python/app/module/system/service/common_service.py +++ b/runtime/datamate-python/app/module/system/service/common_service.py @@ -1,7 +1,7 @@ from typing import Optional from langchain_core.language_models import BaseChatModel -from langchain_openai import ChatOpenAI +from langchain_openai import ChatOpenAI, OpenAIEmbeddings from pydantic import SecretStr from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -27,3 +27,17 @@ def chat(model: BaseChatModel, prompt: str) -> str: """使用指定模型进行聊天""" response = model.invoke(prompt) return response.content + + +# 实例化对象 +def get_openai_client(model: ModelConfig) -> OpenAIEmbeddings: + return OpenAIEmbeddings( + model=model.model_name, + base_url=model.base_url, + api_key=SecretStr(model.api_key), + ) + +# 获取嵌入向量维度 +def get_embedding_dimension(model: OpenAIEmbeddings) -> int: + """获取 OpenAI 模型的嵌入向量维度""" + return len(model.embed_query(model.model))