From 08c8fb265ecd627b0977307ee12ba7a8c3a2b037 Mon Sep 17 00:00:00 2001 From: sid597 Date: Mon, 25 May 2026 12:10:03 +0530 Subject: [PATCH 1/5] ENG-1477: Convert a tldraw arrow to a DG relation Right-click a plain tldraw arrow connecting two discourse-graph nodes to get a "Relation" submenu listing the relation types valid between those node types. Selecting one creates the discourse relation (persisted to Roam) and replaces the plain arrow with the relation arrow. Extract the relation-creation logic shared with the node drag-handle flow into overlays/relationCreation.ts (createRelationBetweenNodes, getValidRelationTypesBetween); DragHandleOverlay and RelationTypeDropdown now delegate to it. --- .../canvas/overlays/DragHandleOverlay.tsx | 111 +------------ .../canvas/overlays/RelationTypeDropdown.tsx | 56 +------ .../canvas/overlays/relationCreation.ts | 147 ++++++++++++++++++ .../src/components/canvas/uiOverrides.tsx | 68 ++++++++ 4 files changed, 228 insertions(+), 154 deletions(-) create mode 100644 apps/roam/src/components/canvas/overlays/relationCreation.ts diff --git a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx index 37d25869c..b6c3ec428 100644 --- a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx +++ b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx @@ -1,20 +1,13 @@ import React, { useCallback, useEffect, useRef, useState } from "react"; -import { TLShapeId, createShapeId, useEditor, useValue } from "tldraw"; +import { TLShapeId, useEditor, useValue } from "tldraw"; import { DiscourseNodeShape } from "~/components/canvas/DiscourseNodeUtil"; import { - BaseDiscourseRelationUtil, - DiscourseRelationShape, - getRelationColor, -} from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil"; -import { createOrUpdateArrowBinding } from "~/components/canvas/DiscourseRelationShape/helpers"; -import { - checkConnectionType, - getAllRelations, hasValidRelationTypes, isDiscourseNodeShape, } from "~/components/canvas/canvasUtils"; import { dispatchToastEvent } from "~/components/canvas/ToastListener"; import { RelationTypeDropdown } from "./RelationTypeDropdown"; +import { createRelationBetweenNodes } from "./relationCreation"; const HANDLE_RADIUS = 5; const HANDLE_HIT_AREA = 12; @@ -256,103 +249,13 @@ export const DragHandleOverlay = () => { (relationId: string) => { if (!pending) return; - const selectedRelation = getAllRelations().find( - (r) => r.id === relationId, - ); - if (!selectedRelation) { - setPending(null); - sourceNodeRef.current = null; - return; - } - - const color = getRelationColor(selectedRelation.label); - - // Determine direction: if we dragged from the relation's destination type, - // the arrow is in reverse and should display the complement label. - const sourceNode = editor.getShape(pending.sourceId); - const targetNode = editor.getShape(pending.targetId); - const { isReverse } = checkConnectionType( - selectedRelation, - sourceNode?.type ?? "", - targetNode?.type ?? "", - ); - const label = - isReverse && selectedRelation.complement - ? selectedRelation.complement - : selectedRelation.label; - - // Get source bounds for arrow positioning - const sourceBounds = editor.getShapePageBounds(pending.sourceId); - if (!sourceBounds) { - setPending(null); - sourceNodeRef.current = null; - return; - } - - // Create the real relation shape with the correct type - const arrowId = createShapeId(); - editor.createShape({ - id: arrowId, - type: relationId, - x: sourceBounds.midX, - y: sourceBounds.midY, - props: { - color, - text: label, - dash: "draw", - size: "m", - fill: "none", - bend: 0, - start: { x: 0, y: 0 }, - end: { x: 0, y: 0 }, - arrowheadStart: "none", - arrowheadEnd: "arrow", - labelPosition: 0.5, - font: "draw", - scale: 1, - }, - }); - - const newArrow = editor.getShape(arrowId); - if (!newArrow) { - setPending(null); - sourceNodeRef.current = null; - return; - } - - // Bind start and end - createOrUpdateArrowBinding(editor, newArrow, pending.sourceId, { - terminal: "start", - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - isExact: false, + createRelationBetweenNodes({ + editor, + relationId, + sourceId: pending.sourceId, + targetId: pending.targetId, }); - createOrUpdateArrowBinding(editor, newArrow, pending.targetId, { - terminal: "end", - normalizedAnchor: { x: 0.5, y: 0.5 }, - isPrecise: false, - isExact: false, - }); - - // Persist via handleCreateRelationsInRoam - const util = editor.getShapeUtil(newArrow); - if ( - util instanceof BaseDiscourseRelationUtil && - "handleCreateRelationsInRoam" in util - ) { - type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { - handleCreateRelationsInRoam: (args: { - arrow: DiscourseRelationShape; - targetId: TLShapeId; - }) => Promise; - }; - void (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ - arrow: editor.getShape(arrowId) ?? newArrow, - targetId: pending.targetId, - }); - } - editor.select(arrowId); setPending(null); sourceNodeRef.current = null; }, diff --git a/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx b/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx index 758a94033..f4997df55 100644 --- a/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx +++ b/apps/roam/src/components/canvas/overlays/RelationTypeDropdown.tsx @@ -1,11 +1,6 @@ import React, { useCallback, useEffect, useMemo, useRef } from "react"; -import { TLShapeId, useEditor, DefaultColorThemePalette } from "tldraw"; -import { getRelationColor } from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil"; -import { - checkConnectionType, - getAllRelations, - isDiscourseNodeShape, -} from "~/components/canvas/canvasUtils"; +import { TLShapeId, useEditor } from "tldraw"; +import { getValidRelationTypesBetween } from "./relationCreation"; type RelationTypeDropdownProps = { sourceId: TLShapeId; @@ -25,49 +20,10 @@ export const RelationTypeDropdown = ({ const editor = useEditor(); const dropdownRef = useRef(null); - // Get valid relation types based on source/target node types - const validRelationTypes = useMemo(() => { - const startNode = editor.getShape(sourceId); - const endNode = editor.getShape(targetId); - if (!startNode || !endNode) return []; - - const startNodeType = startNode.type; - const endNodeType = endNode.type; - - // Verify both are discourse nodes - if ( - !isDiscourseNodeShape(editor, startNode) || - !isDiscourseNodeShape(editor, endNode) - ) - return []; - - const colorPalette = DefaultColorThemePalette.lightMode; - const validTypes: { id: string; label: string; color: string }[] = []; - const allRelations = getAllRelations(); - const seenLabels = new Set(); - - for (const relation of allRelations) { - const { isDirect: isForward, isReverse } = checkConnectionType( - relation, - startNodeType, - endNodeType, - ); - - if (!isForward && !isReverse) continue; - - const label = - isReverse && relation.complement ? relation.complement : relation.label; - - if (!seenLabels.has(label)) { - seenLabels.add(label); - const tldrawColor = getRelationColor(relation.label); - const hexColor = colorPalette[tldrawColor]?.solid ?? "#333"; - validTypes.push({ id: relation.id, label, color: hexColor }); - } - } - - return validTypes; - }, [editor, sourceId, targetId]); + const validRelationTypes = useMemo( + () => getValidRelationTypesBetween(editor, sourceId, targetId), + [editor, sourceId, targetId], + ); // Handle click outside useEffect(() => { diff --git a/apps/roam/src/components/canvas/overlays/relationCreation.ts b/apps/roam/src/components/canvas/overlays/relationCreation.ts new file mode 100644 index 000000000..304bf4e6e --- /dev/null +++ b/apps/roam/src/components/canvas/overlays/relationCreation.ts @@ -0,0 +1,147 @@ +import { + DefaultColorThemePalette, + Editor, + TLShapeId, + createShapeId, +} from "tldraw"; +import { + BaseDiscourseRelationUtil, + DiscourseRelationShape, + getRelationColor, +} from "~/components/canvas/DiscourseRelationShape/DiscourseRelationUtil"; +import { createOrUpdateArrowBinding } from "~/components/canvas/DiscourseRelationShape/helpers"; +import { + checkConnectionType, + getAllRelations, + isDiscourseNodeShape, +} from "~/components/canvas/canvasUtils"; + +type RelationTypeOption = { id: string; label: string; color: string }; + +export const getValidRelationTypesBetween = ( + editor: Editor, + startId: TLShapeId, + endId: TLShapeId, +): RelationTypeOption[] => { + const startNode = editor.getShape(startId); + const endNode = editor.getShape(endId); + if (!startNode || !endNode) return []; + if ( + !isDiscourseNodeShape(editor, startNode) || + !isDiscourseNodeShape(editor, endNode) + ) + return []; + + const colorPalette = DefaultColorThemePalette.lightMode; + const validTypes: RelationTypeOption[] = []; + const seenLabels = new Set(); + + for (const relation of getAllRelations()) { + const { isDirect, isReverse } = checkConnectionType( + relation, + startNode.type, + endNode.type, + ); + if (!isDirect && !isReverse) continue; + + const label = + isReverse && relation.complement ? relation.complement : relation.label; + if (seenLabels.has(label)) continue; + seenLabels.add(label); + + const hexColor = + colorPalette[getRelationColor(relation.label)]?.solid ?? "#333"; + validTypes.push({ id: relation.id, label, color: hexColor }); + } + + return validTypes; +}; + +export const createRelationBetweenNodes = ({ + editor, + relationId, + sourceId, + targetId, +}: { + editor: Editor; + relationId: string; + sourceId: TLShapeId; + targetId: TLShapeId; +}): TLShapeId | null => { + const selectedRelation = getAllRelations().find((r) => r.id === relationId); + if (!selectedRelation) return null; + + const sourceNode = editor.getShape(sourceId); + const targetNode = editor.getShape(targetId); + const { isReverse } = checkConnectionType( + selectedRelation, + sourceNode?.type ?? "", + targetNode?.type ?? "", + ); + const label = + isReverse && selectedRelation.complement + ? selectedRelation.complement + : selectedRelation.label; + + const sourceBounds = editor.getShapePageBounds(sourceId); + if (!sourceBounds) return null; + + const arrowId = createShapeId(); + editor.createShape({ + id: arrowId, + type: relationId, + x: sourceBounds.midX, + y: sourceBounds.midY, + props: { + color: getRelationColor(selectedRelation.label), + text: label, + dash: "draw", + size: "m", + fill: "none", + bend: 0, + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, + arrowheadStart: "none", + arrowheadEnd: "arrow", + labelPosition: 0.5, + font: "draw", + scale: 1, + }, + }); + + const newArrow = editor.getShape(arrowId); + if (!newArrow) return null; + + createOrUpdateArrowBinding(editor, newArrow, sourceId, { + terminal: "start", + normalizedAnchor: { x: 0.5, y: 0.5 }, + isPrecise: false, + isExact: false, + }); + createOrUpdateArrowBinding(editor, newArrow, targetId, { + terminal: "end", + normalizedAnchor: { x: 0.5, y: 0.5 }, + isPrecise: false, + isExact: false, + }); + + const util = editor.getShapeUtil(newArrow); + if ( + util instanceof BaseDiscourseRelationUtil && + "handleCreateRelationsInRoam" in util + ) { + type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { + handleCreateRelationsInRoam: (args: { + arrow: DiscourseRelationShape; + targetId: TLShapeId; + }) => Promise; + }; + void (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ + arrow: editor.getShape(arrowId) ?? newArrow, + targetId, + }); + } + + editor.select(arrowId); + return arrowId; +}; diff --git a/apps/roam/src/components/canvas/uiOverrides.tsx b/apps/roam/src/components/canvas/uiOverrides.tsx index 38025d810..ed2e07035 100644 --- a/apps/roam/src/components/canvas/uiOverrides.tsx +++ b/apps/roam/src/components/canvas/uiOverrides.tsx @@ -1,7 +1,9 @@ import React, { ReactElement } from "react"; import { + TLArrowBinding, TLImageShape, TLShape, + TLShapeId, TLTextShape, TLUiDialogProps, TLUiOverrides, @@ -49,6 +51,11 @@ import { COLOR_ARRAY } from "./DiscourseNodeUtil"; import calcCanvasNodeSizeAndImg from "~/utils/calcCanvasNodeSizeAndImg"; import { AddReferencedNodeType } from "./DiscourseRelationShape/DiscourseRelationTool"; import { getRelationColor } from "./DiscourseRelationShape/DiscourseRelationUtil"; +import { + createRelationBetweenNodes, + getValidRelationTypesBetween, +} from "./overlays/relationCreation"; +import { isDiscourseNodeShape } from "./canvasUtils"; import DiscourseGraphPanel from "./DiscourseToolPanel"; import type { CanvasNodeShortcuts } from "~/components/settings/utils/zodSchema"; import { CustomDefaultToolbar } from "./CustomDefaultToolbar"; @@ -224,6 +231,27 @@ export const getOnSelectForShape = ({ return () => {}; }; +const getArrowBoundNodeIds = ( + editor: Editor, + arrow: TLShape, +): { startId: TLShapeId; endId: TLShapeId } | null => { + const bindings = editor.getBindingsFromShape(arrow, "arrow"); + const startId = bindings.find((b) => b.props.terminal === "start")?.toId; + const endId = bindings.find((b) => b.props.terminal === "end")?.toId; + if (!startId || !endId) return null; + + const startShape = editor.getShape(startId); + const endShape = editor.getShape(endId); + if (!startShape || !endShape) return null; + if ( + !isDiscourseNodeShape(editor, startShape) || + !isDiscourseNodeShape(editor, endShape) + ) + return null; + + return { startId, endId }; +}; + export const CustomContextMenu = ({ extensionAPI, allNodes, @@ -239,6 +267,22 @@ export const CustomContextMenu = ({ ); const isTextSelected = selectedShape?.type === "text"; const isImageSelected = selectedShape?.type === "image"; + const arrowRelationOptions = useValue( + "arrowRelationOptions", + () => { + if (!selectedShape || selectedShape.type !== "arrow") return null; + const boundNodes = getArrowBoundNodeIds(editor, selectedShape); + if (!boundNodes) return null; + const relationTypes = getValidRelationTypesBetween( + editor, + boundNodes.startId, + boundNodes.endId, + ); + if (relationTypes.length === 0) return null; + return { arrowId: selectedShape.id, ...boundNodes, relationTypes }; + }, + [editor, selectedShape], + ); return ( @@ -268,6 +312,30 @@ export const CustomContextMenu = ({ )} + {arrowRelationOptions && ( + + + {arrowRelationOptions.relationTypes.map((rt) => ( + { + const newArrowId = createRelationBetweenNodes({ + editor, + relationId: rt.id, + sourceId: arrowRelationOptions.startId, + targetId: arrowRelationOptions.endId, + }); + if (newArrowId) { + editor.deleteShapes([arrowRelationOptions.arrowId]); + } + }} + /> + ))} + + + )} ); }; From 0a799cc6159656bd91f1af26d6863816fb1da37d Mon Sep 17 00:00:00 2001 From: sid597 Date: Mon, 25 May 2026 15:21:53 +0530 Subject: [PATCH 2/5] ENG-1477: Keep original arrow until relation persistence succeeds Converting a plain arrow deleted the original immediately after the new relation shape was created, but persistence runs asynchronously and deletes the new arrow on failure (e.g. canvas page UID lookup fails), which could leave neither arrow. Await handleCreateRelationsInRoam and treat the new arrow surviving as success; the convert action now removes the original only when the relation persisted. The drag-handle caller voids the now-async helper (no original arrow to lose, behavior unchanged). --- .../canvas/overlays/DragHandleOverlay.tsx | 2 +- .../components/canvas/overlays/relationCreation.ts | 13 ++++++++----- apps/roam/src/components/canvas/uiOverrides.tsx | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx index b6c3ec428..0c7cbc33f 100644 --- a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx +++ b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx @@ -249,7 +249,7 @@ export const DragHandleOverlay = () => { (relationId: string) => { if (!pending) return; - createRelationBetweenNodes({ + void createRelationBetweenNodes({ editor, relationId, sourceId: pending.sourceId, diff --git a/apps/roam/src/components/canvas/overlays/relationCreation.ts b/apps/roam/src/components/canvas/overlays/relationCreation.ts index 304bf4e6e..12256f887 100644 --- a/apps/roam/src/components/canvas/overlays/relationCreation.ts +++ b/apps/roam/src/components/canvas/overlays/relationCreation.ts @@ -57,7 +57,7 @@ export const getValidRelationTypesBetween = ( return validTypes; }; -export const createRelationBetweenNodes = ({ +export const createRelationBetweenNodes = async ({ editor, relationId, sourceId, @@ -67,7 +67,7 @@ export const createRelationBetweenNodes = ({ relationId: string; sourceId: TLShapeId; targetId: TLShapeId; -}): TLShapeId | null => { +}): Promise => { const selectedRelation = getAllRelations().find((r) => r.id === relationId); if (!selectedRelation) return null; @@ -125,6 +125,8 @@ export const createRelationBetweenNodes = ({ isExact: false, }); + editor.select(arrowId); + const util = editor.getShapeUtil(newArrow); if ( util instanceof BaseDiscourseRelationUtil && @@ -136,12 +138,13 @@ export const createRelationBetweenNodes = ({ targetId: TLShapeId; }) => Promise; }; - void (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ + await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ arrow: editor.getShape(arrowId) ?? newArrow, targetId, }); } - editor.select(arrowId); - return arrowId; + // handleCreateRelationsInRoam deletes the new arrow if it rejects the + // conversion, so a surviving shape means the relation was persisted. + return editor.getShape(arrowId) ? arrowId : null; }; diff --git a/apps/roam/src/components/canvas/uiOverrides.tsx b/apps/roam/src/components/canvas/uiOverrides.tsx index ed2e07035..d78e2fa0c 100644 --- a/apps/roam/src/components/canvas/uiOverrides.tsx +++ b/apps/roam/src/components/canvas/uiOverrides.tsx @@ -320,8 +320,8 @@ export const CustomContextMenu = ({ key={rt.id} id={`relation-${rt.id}`} label={rt.label} - onSelect={() => { - const newArrowId = createRelationBetweenNodes({ + onSelect={async () => { + const newArrowId = await createRelationBetweenNodes({ editor, relationId: rt.id, sourceId: arrowRelationOptions.startId, From 42cce455e19003ab1f3cdbf03c89a305c5b07a72 Mon Sep 17 00:00:00 2001 From: sid597 Date: Thu, 28 May 2026 12:43:03 +0530 Subject: [PATCH 3/5] ENG-1477: Bail when bound nodes are missing in convert flow Drop the `sourceNode?.type ?? ""` lie that quietly proceeded past a deleted source/target shape and let the binding step fail silently, losing the original arrow with no toast. Return null early instead. Also drop the `?? newArrow` fallback when re-reading the shape before persistence; the read happens synchronously after createShape + bindings, so there is no scenario where it returns undefined that we should paper over. --- .../src/components/canvas/overlays/relationCreation.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/apps/roam/src/components/canvas/overlays/relationCreation.ts b/apps/roam/src/components/canvas/overlays/relationCreation.ts index 12256f887..e7b0b72b8 100644 --- a/apps/roam/src/components/canvas/overlays/relationCreation.ts +++ b/apps/roam/src/components/canvas/overlays/relationCreation.ts @@ -73,10 +73,11 @@ export const createRelationBetweenNodes = async ({ const sourceNode = editor.getShape(sourceId); const targetNode = editor.getShape(targetId); + if (!sourceNode || !targetNode) return null; const { isReverse } = checkConnectionType( selectedRelation, - sourceNode?.type ?? "", - targetNode?.type ?? "", + sourceNode.type, + targetNode.type, ); const label = isReverse && selectedRelation.complement @@ -139,7 +140,7 @@ export const createRelationBetweenNodes = async ({ }) => Promise; }; await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ - arrow: editor.getShape(arrowId) ?? newArrow, + arrow: newArrow, targetId, }); } From dc41755943fc2ce2fdc27e769a52e5e391e4addf Mon Sep 17 00:00:00 2001 From: sid597 Date: Thu, 28 May 2026 13:50:03 +0530 Subject: [PATCH 4/5] ENG-1477: Preserve arrow geometry when converting --- .../canvas/overlays/DragHandleOverlay.tsx | 4 +- .../canvas/overlays/relationCreation.ts | 2 +- .../src/components/canvas/uiOverrides.tsx | 167 +++++++++++++++--- 3 files changed, 150 insertions(+), 23 deletions(-) diff --git a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx index 0c7cbc33f..f85835a97 100644 --- a/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx +++ b/apps/roam/src/components/canvas/overlays/DragHandleOverlay.tsx @@ -7,7 +7,7 @@ import { } from "~/components/canvas/canvasUtils"; import { dispatchToastEvent } from "~/components/canvas/ToastListener"; import { RelationTypeDropdown } from "./RelationTypeDropdown"; -import { createRelationBetweenNodes } from "./relationCreation"; +import { createDefaultRelationBetweenNodes } from "./relationCreation"; const HANDLE_RADIUS = 5; const HANDLE_HIT_AREA = 12; @@ -249,7 +249,7 @@ export const DragHandleOverlay = () => { (relationId: string) => { if (!pending) return; - void createRelationBetweenNodes({ + void createDefaultRelationBetweenNodes({ editor, relationId, sourceId: pending.sourceId, diff --git a/apps/roam/src/components/canvas/overlays/relationCreation.ts b/apps/roam/src/components/canvas/overlays/relationCreation.ts index e7b0b72b8..d73395531 100644 --- a/apps/roam/src/components/canvas/overlays/relationCreation.ts +++ b/apps/roam/src/components/canvas/overlays/relationCreation.ts @@ -57,7 +57,7 @@ export const getValidRelationTypesBetween = ( return validTypes; }; -export const createRelationBetweenNodes = async ({ +export const createDefaultRelationBetweenNodes = async ({ editor, relationId, sourceId, diff --git a/apps/roam/src/components/canvas/uiOverrides.tsx b/apps/roam/src/components/canvas/uiOverrides.tsx index d78e2fa0c..0e87a3510 100644 --- a/apps/roam/src/components/canvas/uiOverrides.tsx +++ b/apps/roam/src/components/canvas/uiOverrides.tsx @@ -1,6 +1,7 @@ import React, { ReactElement } from "react"; import { TLArrowBinding, + TLArrowShape, TLImageShape, TLShape, TLShapeId, @@ -50,12 +51,18 @@ import { formatHexColor } from "~/components/settings/DiscourseNodeCanvasSetting import { COLOR_ARRAY } from "./DiscourseNodeUtil"; import calcCanvasNodeSizeAndImg from "~/utils/calcCanvasNodeSizeAndImg"; import { AddReferencedNodeType } from "./DiscourseRelationShape/DiscourseRelationTool"; -import { getRelationColor } from "./DiscourseRelationShape/DiscourseRelationUtil"; import { - createRelationBetweenNodes, - getValidRelationTypesBetween, -} from "./overlays/relationCreation"; -import { isDiscourseNodeShape } from "./canvasUtils"; + BaseDiscourseRelationUtil, + DiscourseRelationShape, + getRelationColor, +} from "./DiscourseRelationShape/DiscourseRelationUtil"; +import { getValidRelationTypesBetween } from "./overlays/relationCreation"; +import { + checkConnectionType, + getAllRelations, + isDiscourseNodeShape, +} from "./canvasUtils"; +import { createOrUpdateArrowBinding } from "./DiscourseRelationShape/helpers"; import DiscourseGraphPanel from "./DiscourseToolPanel"; import type { CanvasNodeShortcuts } from "~/components/settings/utils/zodSchema"; import { CustomDefaultToolbar } from "./CustomDefaultToolbar"; @@ -231,17 +238,24 @@ export const getOnSelectForShape = ({ return () => {}; }; -const getArrowBoundNodeIds = ( +type ArrowBoundNodeInfo = { + startId: TLShapeId; + endId: TLShapeId; + startBinding: TLArrowBinding; + endBinding: TLArrowBinding; +}; + +const getArrowBoundNodeInfo = ( editor: Editor, arrow: TLShape, -): { startId: TLShapeId; endId: TLShapeId } | null => { +): ArrowBoundNodeInfo | null => { const bindings = editor.getBindingsFromShape(arrow, "arrow"); - const startId = bindings.find((b) => b.props.terminal === "start")?.toId; - const endId = bindings.find((b) => b.props.terminal === "end")?.toId; - if (!startId || !endId) return null; + const startBinding = bindings.find((b) => b.props.terminal === "start"); + const endBinding = bindings.find((b) => b.props.terminal === "end"); + if (!startBinding || !endBinding) return null; - const startShape = editor.getShape(startId); - const endShape = editor.getShape(endId); + const startShape = editor.getShape(startBinding.toId); + const endShape = editor.getShape(endBinding.toId); if (!startShape || !endShape) return null; if ( !isDiscourseNodeShape(editor, startShape) || @@ -249,7 +263,119 @@ const getArrowBoundNodeIds = ( ) return null; - return { startId, endId }; + return { + startId: startBinding.toId, + endId: endBinding.toId, + startBinding, + endBinding, + }; +}; + +const copyArrowBindingProps = ( + binding: TLArrowBinding, +): TLArrowBinding["props"] => ({ + ...binding.props, + normalizedAnchor: { ...binding.props.normalizedAnchor }, +}); + +const convertArrowToRelation = async ({ + editor, + arrow, + relationId, +}: { + editor: Editor; + arrow: TLArrowShape; + relationId: string; +}): Promise => { + const boundNodes = getArrowBoundNodeInfo(editor, arrow); + if (!boundNodes) return null; + + const selectedRelation = getAllRelations().find((r) => r.id === relationId); + if (!selectedRelation) return null; + + const sourceNode = editor.getShape(boundNodes.startId); + const targetNode = editor.getShape(boundNodes.endId); + if (!sourceNode || !targetNode) return null; + + const { isReverse } = checkConnectionType( + selectedRelation, + sourceNode.type, + targetNode.type, + ); + const label = + isReverse && selectedRelation.complement + ? selectedRelation.complement + : selectedRelation.label; + const relationColor = getRelationColor(selectedRelation.label); + const arrowProps = structuredClone(arrow.props); + const relationArrowId = createShapeId(); + + editor.createShape({ + id: relationArrowId, + type: relationId, + parentId: arrow.parentId, + x: arrow.x, + y: arrow.y, + rotation: arrow.rotation, + opacity: arrow.opacity, + isLocked: arrow.isLocked, + meta: { ...arrow.meta }, + props: { + ...arrowProps, + color: relationColor, + labelColor: relationColor, + text: label, + }, + }); + + const relationArrow = + editor.getShape(relationArrowId); + if (!relationArrow) return null; + + createOrUpdateArrowBinding( + editor, + relationArrow, + boundNodes.startId, + copyArrowBindingProps(boundNodes.startBinding), + ); + createOrUpdateArrowBinding( + editor, + relationArrow, + boundNodes.endId, + copyArrowBindingProps(boundNodes.endBinding), + ); + + const util = editor.getShapeUtil(relationArrow); + if ( + util instanceof BaseDiscourseRelationUtil && + "handleCreateRelationsInRoam" in util + ) { + type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { + handleCreateRelationsInRoam: (args: { + arrow: DiscourseRelationShape; + targetId: TLShapeId; + }) => Promise; + }; + await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ + arrow: relationArrow, + targetId: boundNodes.endId, + }); + } + + const persistedArrow = + editor.getShape(relationArrowId); + if (!persistedArrow) { + editor.select(arrow.id); + return null; + } + + editor.deleteShapes([arrow.id]); + editor.updateShapes([ + { id: persistedArrow.id, type: persistedArrow.type, index: arrow.index }, + ]); + editor.select(relationArrowId); + + return relationArrowId; }; export const CustomContextMenu = ({ @@ -271,7 +397,7 @@ export const CustomContextMenu = ({ "arrowRelationOptions", () => { if (!selectedShape || selectedShape.type !== "arrow") return null; - const boundNodes = getArrowBoundNodeIds(editor, selectedShape); + const boundNodes = getArrowBoundNodeInfo(editor, selectedShape); if (!boundNodes) return null; const relationTypes = getValidRelationTypesBetween( editor, @@ -321,15 +447,16 @@ export const CustomContextMenu = ({ id={`relation-${rt.id}`} label={rt.label} onSelect={async () => { - const newArrowId = await createRelationBetweenNodes({ + const arrow = editor.getShape( + arrowRelationOptions.arrowId, + ); + if (!arrow || arrow.type !== "arrow") return; + + await convertArrowToRelation({ editor, + arrow, relationId: rt.id, - sourceId: arrowRelationOptions.startId, - targetId: arrowRelationOptions.endId, }); - if (newArrowId) { - editor.deleteShapes([arrowRelationOptions.arrowId]); - } }} /> ))} From 50153011195a6cf7d901f9d8bd99e8687448b8c7 Mon Sep 17 00:00:00 2001 From: sid597 Date: Fri, 29 May 2026 13:01:09 +0530 Subject: [PATCH 5/5] ENG-1477: Extract relation helpers and enforce DG visual contract on conversion Pull the directional-label rule and the Roam-persistence call into shared getDirectionalRelationLabel and persistRelationArrow helpers, used by both the drag-handle creator and the conversion path (the label rule previously lived in three places). Convert now sets the relation's visual contract explicitly instead of spreading the original arrow's props: it inherits only geometry (bend, start, end, labelPosition) and forces color, label, arrowheads, dash, size, font, scale, fill. Arrowheads encode relation direction, so a converted double- or reverse-headed arrow would otherwise render a misdirected relation; explicit props also stop future tldraw arrow props from leaking into converted relations. --- .../canvas/overlays/relationCreation.ts | 90 +++++++++++++------ .../src/components/canvas/uiOverrides.tsx | 59 ++++++------ 2 files changed, 89 insertions(+), 60 deletions(-) diff --git a/apps/roam/src/components/canvas/overlays/relationCreation.ts b/apps/roam/src/components/canvas/overlays/relationCreation.ts index d73395531..cfbe618d0 100644 --- a/apps/roam/src/components/canvas/overlays/relationCreation.ts +++ b/apps/roam/src/components/canvas/overlays/relationCreation.ts @@ -15,9 +15,61 @@ import { getAllRelations, isDiscourseNodeShape, } from "~/components/canvas/canvasUtils"; +import type { DiscourseRelation } from "~/utils/getDiscourseRelations"; type RelationTypeOption = { id: string; label: string; color: string }; +type DirectionalRelation = Pick< + DiscourseRelation, + "label" | "complement" | "source" | "destination" +>; + +export const getDirectionalRelationLabel = ({ + relation, + sourceNodeType, + targetNodeType, +}: { + relation: DirectionalRelation; + sourceNodeType: string; + targetNodeType: string; +}): string => { + const { isReverse } = checkConnectionType( + relation, + sourceNodeType, + targetNodeType, + ); + return isReverse && relation.complement + ? relation.complement + : relation.label; +}; + +export const persistRelationArrow = async ({ + editor, + arrow, + targetId, +}: { + editor: Editor; + arrow: DiscourseRelationShape; + targetId: TLShapeId; +}): Promise => { + const util = editor.getShapeUtil(arrow); + if ( + util instanceof BaseDiscourseRelationUtil && + "handleCreateRelationsInRoam" in util + ) { + type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { + handleCreateRelationsInRoam: (args: { + arrow: DiscourseRelationShape; + targetId: TLShapeId; + }) => Promise; + }; + await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ + arrow, + targetId, + }); + } +}; + export const getValidRelationTypesBetween = ( editor: Editor, startId: TLShapeId, @@ -44,8 +96,11 @@ export const getValidRelationTypesBetween = ( ); if (!isDirect && !isReverse) continue; - const label = - isReverse && relation.complement ? relation.complement : relation.label; + const label = getDirectionalRelationLabel({ + relation, + sourceNodeType: startNode.type, + targetNodeType: endNode.type, + }); if (seenLabels.has(label)) continue; seenLabels.add(label); @@ -74,15 +129,11 @@ export const createDefaultRelationBetweenNodes = async ({ const sourceNode = editor.getShape(sourceId); const targetNode = editor.getShape(targetId); if (!sourceNode || !targetNode) return null; - const { isReverse } = checkConnectionType( - selectedRelation, - sourceNode.type, - targetNode.type, - ); - const label = - isReverse && selectedRelation.complement - ? selectedRelation.complement - : selectedRelation.label; + const label = getDirectionalRelationLabel({ + relation: selectedRelation, + sourceNodeType: sourceNode.type, + targetNodeType: targetNode.type, + }); const sourceBounds = editor.getShapePageBounds(sourceId); if (!sourceBounds) return null; @@ -128,22 +179,7 @@ export const createDefaultRelationBetweenNodes = async ({ editor.select(arrowId); - const util = editor.getShapeUtil(newArrow); - if ( - util instanceof BaseDiscourseRelationUtil && - "handleCreateRelationsInRoam" in util - ) { - type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { - handleCreateRelationsInRoam: (args: { - arrow: DiscourseRelationShape; - targetId: TLShapeId; - }) => Promise; - }; - await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ - arrow: newArrow, - targetId, - }); - } + await persistRelationArrow({ editor, arrow: newArrow, targetId }); // handleCreateRelationsInRoam deletes the new arrow if it rejects the // conversion, so a surviving shape means the relation was persisted. diff --git a/apps/roam/src/components/canvas/uiOverrides.tsx b/apps/roam/src/components/canvas/uiOverrides.tsx index 0e87a3510..77f73e19c 100644 --- a/apps/roam/src/components/canvas/uiOverrides.tsx +++ b/apps/roam/src/components/canvas/uiOverrides.tsx @@ -52,16 +52,15 @@ import { COLOR_ARRAY } from "./DiscourseNodeUtil"; import calcCanvasNodeSizeAndImg from "~/utils/calcCanvasNodeSizeAndImg"; import { AddReferencedNodeType } from "./DiscourseRelationShape/DiscourseRelationTool"; import { - BaseDiscourseRelationUtil, DiscourseRelationShape, getRelationColor, } from "./DiscourseRelationShape/DiscourseRelationUtil"; -import { getValidRelationTypesBetween } from "./overlays/relationCreation"; import { - checkConnectionType, - getAllRelations, - isDiscourseNodeShape, -} from "./canvasUtils"; + getDirectionalRelationLabel, + getValidRelationTypesBetween, + persistRelationArrow, +} from "./overlays/relationCreation"; +import { getAllRelations, isDiscourseNodeShape } from "./canvasUtils"; import { createOrUpdateArrowBinding } from "./DiscourseRelationShape/helpers"; import DiscourseGraphPanel from "./DiscourseToolPanel"; import type { CanvasNodeShortcuts } from "~/components/settings/utils/zodSchema"; @@ -297,17 +296,12 @@ const convertArrowToRelation = async ({ const targetNode = editor.getShape(boundNodes.endId); if (!sourceNode || !targetNode) return null; - const { isReverse } = checkConnectionType( - selectedRelation, - sourceNode.type, - targetNode.type, - ); - const label = - isReverse && selectedRelation.complement - ? selectedRelation.complement - : selectedRelation.label; + const label = getDirectionalRelationLabel({ + relation: selectedRelation, + sourceNodeType: sourceNode.type, + targetNodeType: targetNode.type, + }); const relationColor = getRelationColor(selectedRelation.label); - const arrowProps = structuredClone(arrow.props); const relationArrowId = createShapeId(); editor.createShape({ @@ -321,7 +315,17 @@ const convertArrowToRelation = async ({ isLocked: arrow.isLocked, meta: { ...arrow.meta }, props: { - ...arrowProps, + bend: arrow.props.bend, + start: structuredClone(arrow.props.start), + end: structuredClone(arrow.props.end), + labelPosition: arrow.props.labelPosition, + dash: "draw", + size: "m", + fill: "none", + arrowheadStart: "none", + arrowheadEnd: "arrow", + font: "draw", + scale: 1, color: relationColor, labelColor: relationColor, text: label, @@ -345,22 +349,11 @@ const convertArrowToRelation = async ({ copyArrowBindingProps(boundNodes.endBinding), ); - const util = editor.getShapeUtil(relationArrow); - if ( - util instanceof BaseDiscourseRelationUtil && - "handleCreateRelationsInRoam" in util - ) { - type UtilWithRoamPersistence = BaseDiscourseRelationUtil & { - handleCreateRelationsInRoam: (args: { - arrow: DiscourseRelationShape; - targetId: TLShapeId; - }) => Promise; - }; - await (util as UtilWithRoamPersistence).handleCreateRelationsInRoam({ - arrow: relationArrow, - targetId: boundNodes.endId, - }); - } + await persistRelationArrow({ + editor, + arrow: relationArrow, + targetId: boundNodes.endId, + }); const persistedArrow = editor.getShape(relationArrowId);