packages/remirror__core-utils/src/prosemirror-utils.ts
import { ErrorConstant } from '@remirror/core-constants';
import {
entries,
invariant,
isArray,
isEmptyObject,
isNonEmptyArray,
isNullOrUndefined,
isString,
object,
} from '@remirror/core-helpers';
import type {
AnyFunction,
AttributesProps,
EditorSchema,
EditorState,
EditorView,
Fragment,
KeyBindingCommandFunction,
KeyBindings,
Mark,
MarkTypesProps,
NodeTypeProps,
NodeTypesProps,
OptionalMarkProps,
OptionalProsemirrorNodeProps,
PosProps,
ProsemirrorAttributes,
ProsemirrorCommandFunction,
ProsemirrorKeyBindings,
ProsemirrorNode,
ProsemirrorNodeProps,
ResolvedPos,
Selection,
SelectionProps,
Transaction,
TransactionProps,
} from '@remirror/core-types';
import type { MarkSpec, NodeSpec, NodeType } from '@remirror/pm/model';
import { Selection as PMSelection } from '@remirror/pm/state';
import { isEditorState, isNodeSelection, isResolvedPos, isSelection } from './core-utils';
import { isTextDomNode } from './dom-utils';
interface NodeEqualsTypeProps extends NodeTypesProps, OptionalProsemirrorNodeProps {}
/**
* Checks if the type a given `node` has a given `nodeType`.
*/
export function isNodeOfType(props: NodeEqualsTypeProps): boolean {
const { types, node } = props;
if (!node) {
return false;
}
const matches = (type: NodeType | string) => type === node.type || type === node.type.name;
if (isArray(types)) {
return types.some(matches);
}
return matches(types);
}
interface MarkEqualsTypeProps extends MarkTypesProps, OptionalMarkProps {}
/**
* Creates a new transaction object from a given transaction. This is useful
* when applying changes to a transaction, that you may want to rollback.
*
* ```ts
* function() applyUpdateIfValid(state: EditorState) {
* const tr = cloneTransaction(state.tr);
*
* tr.insertText('hello');
*
* if (!checkValid(tr)) {
* return;
* }
*
* applyClonedTransaction({ clone: tr, tr: state.tr });
* }
* ```
*
* The above example applies a transaction to the cloned transaction then checks
* to see if the changes are still valid and if they are applies the mutative
* changes to the original state transaction.
*
* @param tr - the prosemirror transaction
*/
export function cloneTransaction(tr: Transaction): Transaction {
return Object.assign(Object.create(tr), tr).setTime(Date.now());
}
interface ApplyClonedTransactionProps extends TransactionProps {
/**
* The clone.
*/
clone: Transaction;
}
/**
* Get the diff between two ordered arrays with a reference equality check.
*/
function diff<Type>(primary: Type[], other: Type[]): Type[] {
return primary.filter((item, index) => item !== other[index]);
}
/**
* Apply the steps of a cloned transaction to the original transaction `tr`.
*/
export function applyClonedTransaction(props: ApplyClonedTransactionProps): void {
const { clone, tr } = props;
const steps = diff(clone.steps, tr.steps);
for (const step of steps) {
tr.step(step);
}
}
/**
* Returns a new transaction by combining all steps of the passed transactions onto the previous state
*/
export function composeTransactionSteps(
transactions: readonly Transaction[],
oldState: EditorState,
): Transaction {
const { tr } = oldState;
transactions.forEach((transaction) => {
transaction.steps.forEach((step) => {
tr.step(step);
});
});
return tr;
}
/**
* Checks if the type a given `node` has a given `nodeType`.
*/
export function markEqualsType(props: MarkEqualsTypeProps): boolean {
const { types, mark } = props;
return mark ? (Array.isArray(types) && types.includes(mark.type)) || mark.type === types : false;
}
interface RemoveNodeAtPositionProps extends TransactionProps, PosProps {}
/**
* Performs a `delete` transaction that removes a node at a given position with
* the given `node`. `position` should point at the position immediately before
* the node.
*
* @param position - the prosemirror position
*/
export function removeNodeAtPosition({ pos, tr }: RemoveNodeAtPositionProps): Transaction {
const node = tr.doc.nodeAt(pos);
if (node) {
tr.delete(pos, pos + node.nodeSize);
}
return tr;
}
interface ReplaceNodeAtPositionProps extends RemoveNodeAtPositionProps {
content: Fragment | ProsemirrorNode | ProsemirrorNode[];
}
/**
* Replaces the node at the provided position with the provided content.
*/
export function replaceNodeAtPosition({
pos,
tr,
content,
}: ReplaceNodeAtPositionProps): Transaction {
const node = tr.doc.nodeAt(pos);
if (node) {
tr.replaceWith(pos, pos + node.nodeSize, content);
}
return tr;
}
/**
* Returns DOM reference of a node at a given `position`.
*
* @remarks
*
* If the node type is of type `TEXT_NODE` it will return the reference of the
* parent node.
*
* A simple use case
*
* ```ts
* const element = findElementAtPosition($from.pos, view);
* ```
*
* @param position - the prosemirror position
* @param view - the editor view
*/
export function findElementAtPosition(position: number, view: EditorView): HTMLElement {
const dom = view.domAtPos(position);
const node = dom.node.childNodes[dom.offset];
if (isTextDomNode(dom.node)) {
return dom.node.parentNode as HTMLElement;
}
if (isNullOrUndefined(node) || isTextDomNode(node)) {
return dom.node as HTMLElement;
}
return node as HTMLElement;
}
/**
* Iterates over parent nodes, returning the closest node and its start position
* that the `predicate` returns truthy for. `start` points to the start position
* of the node, `pos` points directly before the node.
*
* ```ts
* const predicate = node => node.type === schema.nodes.blockquote;
* const parent = findParentNode({ predicate, selection });
* ```
*/
export function findParentNode(props: FindParentNodeProps): FindProsemirrorNodeResult | undefined {
const { predicate, selection } = props;
const $pos = isEditorState(selection)
? selection.selection.$from
: isSelection(selection)
? selection.$from
: selection;
for (let depth = $pos.depth; depth > 0; depth--) {
const node = $pos.node(depth);
const pos = depth > 0 ? $pos.before(depth) : 0;
const start = $pos.start(depth);
const end = pos + node.nodeSize;
if (predicate(node, pos)) {
return { pos, depth, node, start, end };
}
}
return;
}
/**
* Finds the node at the resolved position.
*
* @param $pos - the resolve position in the document
*/
export function findNodeAtPosition($pos: ResolvedPos): FindProsemirrorNodeResult {
const { depth } = $pos;
const pos = depth > 0 ? $pos.before(depth) : 0;
const node = $pos.node(depth);
const start = $pos.start(depth);
const end = pos + node.nodeSize;
return { pos, start, node, end, depth };
}
/**
* Finds the node at the passed selection.
*/
export function findNodeAtSelection(selection: Selection): FindProsemirrorNodeResult {
const parentNode = findParentNode({ predicate: () => true, selection });
invariant(parentNode, { message: 'No parent node found for the selection provided.' });
return parentNode;
}
interface FindParentNodeOfTypeProps extends NodeTypesProps, StateSelectionPosProps {}
/**
* Iterates over parent nodes, returning closest node of a given `nodeType`.
* `start` points to the start position of the node, `pos` points directly
* before the node.
*
* ```ts
* const parent = findParentNodeOfType({types: schema.nodes.paragraph, selection});
* ```
*/
export function findParentNodeOfType(
props: FindParentNodeOfTypeProps,
): FindProsemirrorNodeResult | undefined {
const { types, selection } = props;
return findParentNode({ predicate: (node) => isNodeOfType({ types, node }), selection });
}
/**
* Returns position of the previous node.
*
* ```ts
* const pos = findPositionOfNodeBefore(tr.selection);
* ```
*
* @param selection - the prosemirror selection
*
* @deprecated This util is hard to use and not that useful
*/
export function findPositionOfNodeBefore(
value: Selection | ResolvedPos | EditorState | Transaction,
): FindProsemirrorNodeResult | undefined {
const $pos = isResolvedPos(value)
? value
: isSelection(value)
? value.$from
: value.selection.$from;
if (isNullOrUndefined($pos)) {
throw new Error('Invalid value passed in.');
}
const { nodeBefore } = $pos;
const selection = PMSelection.findFrom($pos, -1);
if (!selection || !nodeBefore) {
return;
}
const parent = findParentNodeOfType({ types: nodeBefore.type, selection });
return parent
? parent
: {
node: nodeBefore,
pos: selection.$from.pos,
end: selection.$from.end(),
depth: selection.$from.depth + 1,
start: selection.$from.start(selection.$from.depth + 1),
};
}
/**
* Updates the provided transaction to remove the node before.
*
* ```ts
* dispatch(
* removeNodeBefore(state.tr)
* );
* ```
*
* @param tr
*
* @deprecated This util is hard to use and not that useful
*/
export function removeNodeBefore(tr: Transaction): Transaction {
const result = findPositionOfNodeBefore(tr.selection);
if (result) {
removeNodeAtPosition({ pos: result.pos, tr });
}
return tr;
}
interface FindSelectedNodeOfTypeProps extends NodeTypesProps, SelectionProps {}
/**
* Returns a node of a given `nodeType` if it is selected. `start` points to the
* start position of the node, `pos` points directly before the node.
*
* ```ts
* const { extension, inlineExtension, bodiedExtension } = schema.nodes;
*
* const selectedNode = findSelectedNodeOfType({
* types: [extension, inlineExtension, bodiedExtension],
* selection,
* });
* ```
*/
export function findSelectedNodeOfType(
props: FindSelectedNodeOfTypeProps,
): FindProsemirrorNodeResult | undefined {
const { types, selection } = props;
if (!isNodeSelection(selection) || !isNodeOfType({ types, node: selection.node })) {
return;
}
return {
pos: selection.$from.pos,
depth: selection.$from.depth,
start: selection.$from.start(),
end: selection.$from.pos + selection.node.nodeSize,
node: selection.node,
};
}
export interface FindProsemirrorNodeResult extends ProsemirrorNodeProps {
/**
* The start position of the node.
*/
start: number;
/**
* The end position of the node.
*/
end: number;
/**
* Points to position directly before the node.
*/
pos: number;
/**
* The depth the node. Equal to 0 if node is the root.
*/
depth: number;
}
interface StateSelectionPosProps {
/**
* Provide an editor state, or the editor selection or a resolved position.
*/
selection: EditorState | Selection | ResolvedPos;
}
interface FindParentNodeProps extends StateSelectionPosProps {
predicate: (node: ProsemirrorNode, pos: number) => boolean;
}
/**
* Returns the position of the node after the current position, selection or
* state.
*
* ```ts
* const pos = findPositionOfNodeBefore(tr.selection);
* ```
*
* @param selection - the prosemirror selection
*
* @deprecated This util is hard to use and not that useful
*/
export function findPositionOfNodeAfter(
value: Selection | ResolvedPos | EditorState,
): FindProsemirrorNodeResult | undefined {
const $pos = isResolvedPos(value)
? value
: isSelection(value)
? value.$from
: value.selection.$from;
if (isNullOrUndefined($pos)) {
throw new Error('Invalid value passed in.');
}
const { nodeAfter } = $pos;
const selection = PMSelection.findFrom($pos, 1);
if (!selection || !nodeAfter) {
return;
}
const parent = findParentNodeOfType({ types: nodeAfter.type, selection });
return parent
? parent
: {
node: nodeAfter,
pos: selection.$from.pos,
end: selection.$from.end(),
depth: selection.$from.depth + 1,
start: selection.$from.start(selection.$from.depth + 1),
};
}
/**
* Update the transaction to delete the node after the current selection.
*
* ```ts
* dispatch(removeNodeBefore(state.tr));
* ```
*
* @param tr
*
* @deprecated This util is hard to use and not that useful
*/
export function removeNodeAfter(tr: Transaction): Transaction {
const result = findPositionOfNodeAfter(tr.selection);
if (result) {
removeNodeAtPosition({ pos: result.pos, tr });
}
return tr;
}
/**
* Checks whether the selection or state is currently empty.
*
* @param value - the transaction selection or state
*/
export function isSelectionEmpty(value: Transaction | EditorState | Selection): boolean {
return isSelection(value) ? value.empty : value.selection.empty;
}
/**
* Check to see if a transaction has changed either the document or the current
* selection.
*
* @param tr - the transaction to check
*/
export function hasTransactionChanged(tr: Transaction): boolean {
return tr.docChanged || tr.selectionSet;
}
/**
* Checks whether the node type passed in is active within the region. Used by
* extensions to implement the `active` method.
*
* To ignore `attrs` just leave the attrs object empty or undefined.
*
* @param props - see [[`GetActiveAttrsProps`]]
*/
export function isNodeActive(props: GetActiveAttrsProps): boolean {
return !!getActiveNode(props);
}
interface GetActiveAttrsProps extends NodeTypeProps, Partial<AttributesProps> {
/**
* State or transaction parameter.
*/
state: EditorState | Transaction;
}
/**
* Get node of a provided type with the provided attributes if it exists as a
* parent. Returns positional data for the node that was found.
*/
export function getActiveNode(props: GetActiveAttrsProps): FindProsemirrorNodeResult | undefined {
const { state, type, attrs } = props;
const { selection, doc } = state;
const nodeType = isString(type) ? doc.type.schema.nodes[type] : type;
invariant(nodeType, { code: ErrorConstant.SCHEMA, message: `No node exists for ${type}` });
const active =
findSelectedNodeOfType({ selection, types: type }) ??
findParentNode({ predicate: (node: ProsemirrorNode) => node.type === nodeType, selection });
if (!attrs || isEmptyObject(attrs) || !active) {
return active;
}
return active.node.hasMarkup(nodeType, { ...active.node.attrs, ...attrs }) ? active : undefined;
}
/**
* The ProseMirror `Schema` as a JSON object.
*/
export interface SchemaJSON<Nodes extends string = string, Marks extends string = string> {
/**
* The nodes of the schema.
*/
nodes: Record<Nodes, NodeSpec>;
/**
* The marks within the schema.
*/
marks: Record<Marks, MarkSpec>;
}
/**
* Converts a `schema` to a JSON compatible object.
*/
export function schemaToJSON<Nodes extends string = string, Marks extends string = string>(
schema: EditorSchema,
): SchemaJSON<Nodes, Marks> {
const nodes: SchemaJSON['nodes'] = object();
const marks: SchemaJSON['marks'] = object();
for (const [key, { spec }] of entries(schema.nodes)) {
nodes[key] = spec;
}
for (const [key, { spec }] of entries(schema.marks)) {
marks[key] = spec;
}
return {
nodes,
marks,
};
}
/**
* Chains together keybindings, allowing for the same key binding to be used
* across multiple extensions without overriding behavior.
*
* @remarks
*
* When `next` is called it hands over full control of the keybindings to the
* function that invokes it.
*/
export function chainKeyBindingCommands(
...commands: KeyBindingCommandFunction[]
): KeyBindingCommandFunction {
return (props) => {
// When no commands are passed just ignore and continue.
if (!isNonEmptyArray(commands)) {
return false;
}
const [command, ...rest] = commands;
// Keeps track of whether the `next` method has been called. If it has been
// called we return the result and skip the rest of the downstream commands.
let calledNext = false;
/**
* Create the next function call. Updates the outer closure when the next
* method has been called.
*/
const createNext =
(...nextCommands: KeyBindingCommandFunction[]): (() => boolean) =>
() => {
// If there are no commands then this can be ignored and continued.
if (!isNonEmptyArray(nextCommands)) {
return false;
}
// Update the closure with information that the next method was invoked by
// this command.
calledNext = true;
const [, ...nextRest] = nextCommands;
// Recursively call the key bindings method.
return chainKeyBindingCommands(...nextCommands)({
...props,
next: createNext(...nextRest),
});
};
const next = createNext(...rest);
const exitEarly = command({ ...props, next });
// Exit the chain of commands early if either:
// - a) next was called
// - b) the command returned true
if (calledNext || exitEarly) {
return exitEarly;
}
// Continue to the next function in the chain of commands.
return next();
};
}
/**
* Used to merge key bindings together in a sensible way. Identical key bindings
* likely have the same key. as a result a naive merge would result in the
* binding added later in the merge being the only one the editor sees.
*
* This creator is used to create a merge that steps from the highest priority
* to the lowest priority giving each keybinding in the chain an opportunity to
* be run, and defer to the next command in the chain or choose not to.
*
* - It is used to create the [[`mergeKeyBindings`]] function helper.
* - It is used to create the [[`mergeProsemirrorKeyBindings`]] function helper.
*
* @typeParam [Schema] - the schema that is being used to create this command.
* @typeParam [Type] - the mapper type signature which is what the `mapper`
* param transforms the [[`KeyBindingCommandFunction`]] into.
*
* @param extensionKeymaps - the list of extension keymaps similar to the
* following:
* ```ts
* [{ Enter: () => false}, { Escape: () => true }, { Enter: () => true }]
* ```
* @param mapper - used to convert the [[`KeyBindingCommandFunction`]] into a
* function with a different signature. It's application can be seen in
* [[`mergeKeyBindings`]] and [[`mergeProsemirrorKeyBindings`]].
*
*/
function mergeKeyBindingCreator<Mapper extends AnyFunction = KeyBindingCommandFunction>(
extensionKeymaps: KeyBindings[],
mapper: (command: KeyBindingCommandFunction) => Mapper,
): Record<string, Mapper> {
// Keep track of the previous commands as we loop through the `extensionKeymaps`.
const previousCommandsMap = new Map<string, KeyBindingCommandFunction[]>();
// This is the combined mapping of commands. Essentially this function turns
// the `extensionKeymaps` array into a single object `extensionKeymap` which
// composes each function to give full control to the developer.
const mappedCommands: Record<string, Mapper> = object();
// Outer loop iterates over each object keybinding.
for (const extensionKeymap of extensionKeymaps) {
// Inner loop checks each keybinding on the keybinding object. `key` refers
// to the name of the keyboard combination, like `Shift-Enter` or
// `Cmd-Escape`.
for (const [key, newCommand] of entries(extensionKeymap)) {
// Get the previous commands for this key if it already exists
const previousCommands: KeyBindingCommandFunction[] = previousCommandsMap.get(key) ?? [];
// Update the commands array. This will be added to the
// `previousCommandsMap` to track the current keyboard combination.
const commands = [...previousCommands, newCommand];
// Chain the keyboard binding so that you have all the niceties, like
// being able to call `next` to run the remaining commands in the chain.
const command = chainKeyBindingCommands(...commands);
// Update the previous commands with the new commands that are now being used.
previousCommandsMap.set(key, commands);
// Store a copy of the mapped commands. If this was the last time this
// loop ran, then this is the command that would be called when a users
// enters the keyboard combination specified by the `key` in this context.
mappedCommands[key] = mapper(command);
}
}
return mappedCommands;
}
/**
* This merges an array of keybindings into one keybinding with the priority
* given to the items earlier in the array. `index: 0` has priority over `index:
* 1` which has priority over `index: 2` and so on.
*
* This is for use on remirror keybindings. See `mergeProsemirrorKeyBindings`
* for transforming the methods into `ProsemirrorCommandFunction`'s.
*/
export function mergeKeyBindings(extensionKeymaps: KeyBindings[]): KeyBindings {
return mergeKeyBindingCreator(extensionKeymaps, (command) => command);
}
/**
* This merges an array of keybindings into one keybinding with the priority
* given to the items earlier in the array. `index: 0` has priority over `index:
* 1` which has priority over `index: 2` and so on.
*
* This supports the [[ProsemirrorCommandFunction]] type signature where the
* `state`, `dispatch` and `view` are passed as separate arguments.
*/
export function mergeProsemirrorKeyBindings(
extensionKeymaps: KeyBindings[],
): ProsemirrorKeyBindings {
return mergeKeyBindingCreator(
extensionKeymaps,
// Convert the command to have a signature of the
// [[`ProsemirrorCommandFunction`]].
(command): ProsemirrorCommandFunction =>
(state, dispatch, view) =>
command({ state, dispatch, view, tr: state.tr, next: () => false }),
);
}
/**
* Determines if a Node or Mark contains the given attributes in its attributes set
*
* @param nodeOrMark - The Node or Mark to check
* @param attrs - The set of attributes it must contain
*/
export function containsAttributes(
nodeOrMark: ProsemirrorNode | Mark,
attrs: ProsemirrorAttributes,
): boolean {
const currentAttrs = nodeOrMark.attrs ?? {};
return Object.entries(attrs).every(([name, value]) => currentAttrs[name] === value);
}