diff --git a/PolymorphicBlocks b/PolymorphicBlocks index 274f953d..9e176023 160000 --- a/PolymorphicBlocks +++ b/PolymorphicBlocks @@ -1 +1 @@ -Subproject commit 274f953defb1e42754a5724c0aaab51e4f21d91d +Subproject commit 9e17602308fcd800e19c29b3125bbe62b0310a17 diff --git a/src/main/scala/edg_ide/edgir_graph/EdgirGraph.scala b/src/main/scala/edg_ide/edgir_graph/EdgirGraph.scala index a904e726..1c33a1b4 100644 --- a/src/main/scala/edg_ide/edgir_graph/EdgirGraph.scala +++ b/src/main/scala/edg_ide/edgir_graph/EdgirGraph.scala @@ -84,29 +84,51 @@ object EdgirGraph { ): Seq[EdgirEdge] = { constraints.flatMap { case (name, constr) => constr.expr match { - case expr.ValueExpr.Expr.Connected(connect) => - // in the loading pass, the source is the block side and the target is the link side - Some( - EdgirEdge( - ConnectWrapper(path + name, constr), - source = Ref.unapply(connect.getBlockPort.getRef).get.slice(0, 2), // only block and port, ignore arrays - target = Ref.unapply(connect.getLinkPort.getRef).get.slice(0, 2) - ) - ) - case expr.ValueExpr.Expr.Exported(export) => - // in the loading pass, the source is the block side and the target is the external port - Some( - EdgirEdge( - ConnectWrapper(path + name, constr), - source = Ref.unapply(export.getInternalBlockPort.getRef).get.slice(0, 2), - target = Ref.unapply(`export`.getExteriorPort.getRef).get.slice(0, 1) - ) - ) - case _ => None + case expr.ValueExpr.Expr.Connected(connected) => + connectedToEdge(path, name, constr, connected) + case expr.ValueExpr.Expr.Exported(exported) => + exportedToEdge(path, name, constr, exported) + case expr.ValueExpr.Expr.ConnectedArray(connectedArray) => + connectedArray.expanded.flatMap(connectedToEdge(path, name, constr, _)) + case expr.ValueExpr.Expr.ExportedArray(exportedArray) => + exportedArray.expanded.flatMap(exportedToEdge(path, name, constr, _)) + case _ => Seq() } }.toSeq } + protected def connectedToEdge( + path: DesignPath, + constrName: String, + constr: expr.ValueExpr, + connected: expr.ConnectedExpr + ): Seq[EdgirEdge] = connected.expanded match { + case Seq() => Seq( // in the loading pass, the source is the block side and the target is the link side + EdgirEdge( + ConnectWrapper(path + constrName, constr), + source = Ref.unapply(connected.getBlockPort.getRef).get.slice(0, 2), // only block and port, ignore arrays + target = Ref.unapply(connected.getLinkPort.getRef).get.slice(0, 2) + )) + case Seq(expanded) => connectedToEdge(path, constrName, constr, expanded) + case _ => throw new IllegalArgumentException("unexpected multiple expanded") + } + + protected def exportedToEdge( + path: DesignPath, + constrName: String, + constr: expr.ValueExpr, + exported: expr.ExportedExpr + ): Seq[EdgirEdge] = exported.expanded match { + case Seq() => Seq( // in the loading pass, the source is the block side and the target is the external port + EdgirEdge( + ConnectWrapper(path + constrName, constr), + source = Ref.unapply(exported.getInternalBlockPort.getRef).get.slice(0, 2), + target = Ref.unapply(exported.getExteriorPort.getRef).get.slice(0, 1) + )) + case Seq(expanded) => exportedToEdge(path, constrName, constr, expanded) + case _ => throw new IllegalArgumentException("unexpected multiple expanded") + } + def blockLikeToNode(path: DesignPath, blockLike: elem.BlockLike): EdgirNode = { blockLike.`type` match { case elem.BlockLike.Type.Hierarchy(block) => diff --git a/src/main/scala/edg_ide/psi_edits/InsertAction.scala b/src/main/scala/edg_ide/psi_edits/InsertAction.scala index 1c89b0b9..5ab8d86d 100644 --- a/src/main/scala/edg_ide/psi_edits/InsertAction.scala +++ b/src/main/scala/edg_ide/psi_edits/InsertAction.scala @@ -44,6 +44,7 @@ object InsertAction { // given the leaf element at the caret, returns the rootmost element right before the caret def prevElementOf(element: PsiElement): PsiElement = { + requireExcept(element.getTextRange != null, "element with null range") // if traversing beyond file level if (element.getTextRange.getStartOffset == caretOffset) { // caret at beginning of element, so take the previous val prev = PsiTreeUtil.prevLeaf(element) prevElementOf(prev.exceptNull("no element before caret")) diff --git a/src/main/scala/edg_ide/psi_edits/InsertBlockAction.scala b/src/main/scala/edg_ide/psi_edits/InsertBlockAction.scala index d0a3e105..c84299ea 100644 --- a/src/main/scala/edg_ide/psi_edits/InsertBlockAction.scala +++ b/src/main/scala/edg_ide/psi_edits/InsertBlockAction.scala @@ -1,16 +1,15 @@ package edg_ide.psi_edits -import com.intellij.codeInsight.template.impl.TemplateState import com.intellij.openapi.application.{ModalityState, ReadAction} import com.intellij.openapi.command.WriteCommandAction.writeCommandAction import com.intellij.openapi.project.Project +import com.intellij.psi.PsiElement import com.intellij.psi.util.PsiTreeUtil -import com.intellij.psi.{PsiElement, PsiWhiteSpace} import com.intellij.util.concurrency.AppExecutorUtil import com.jetbrains.python.psi._ import edg.util.Errorable -import edg_ide.util.ExceptionNotifyImplicits.{ExceptNotify, ExceptOption, ExceptSeq} -import edg_ide.util.{DesignAnalysisUtils, exceptable, requireExcept} +import edg_ide.util.ExceptionNotifyImplicits.{ExceptNotify, ExceptSeq} +import edg_ide.util.{DesignAnalysisUtils, exceptable} import java.util.concurrent.Callable @@ -111,149 +110,4 @@ object InsertBlockAction { () => insertBlockFlow } - /** Creates an action to start a live template to insert a block. - */ - def createTemplateBlock( - contextClass: PyClass, - libClass: PyClass, - actionName: String, - project: Project, - continuation: (String, PsiElement) => Unit - ): Errorable[() => Unit] = exceptable { - val languageLevel = LanguageLevel.forElement(libClass) - val psiElementGenerator = PyElementGenerator.getInstance(project) - - // given some caret position, returns the best insertion position - def getInsertionElt(caretEltOpt: Option[PsiElement]): PsiElement = { - exceptable { // TODO better propagation of error messages - val caretElt = caretEltOpt.exceptNone("no elt at caret") - val caretStatement = InsertAction.snapInsertionEltOfType[PyStatement](caretElt).get - val containingPsiFn = PsiTreeUtil - .getParentOfType(caretStatement, classOf[PyFunction]) - .exceptNull(s"caret not in a function") - val containingPsiClass = PsiTreeUtil - .getParentOfType(containingPsiFn, classOf[PyClass]) - .exceptNull(s"caret not in a class") - requireExcept(containingPsiClass == contextClass, s"caret not in class of type ${libClass.getName}") - caretStatement - }.toOption.orElse { - val candidates = - InsertAction.findInsertionElements(contextClass, InsertBlockAction.VALID_FUNCTION_NAMES) - candidates.headOption - }.get // TODO insert contents() if needed - } - - val movableLiveTemplate = new MovableLiveTemplate(actionName) { - override def startTemplate(caretEltOpt: Option[PsiElement]): InsertionLiveTemplate = { - val insertAfter = getInsertionElt(caretEltOpt) - val containingPsiFn = PsiTreeUtil.getParentOfType(insertAfter, classOf[PyFunction]) - val containingPsiClass = PsiTreeUtil.getParentOfType(containingPsiFn, classOf[PyClass]) - val selfName = containingPsiFn.getParameterList.getParameters.toSeq - .exceptEmpty(s"function ${containingPsiFn.getName} has no self") - .head - .getName - - val newAssignTemplate = psiElementGenerator.createFromText( - languageLevel, - classOf[PyAssignmentStatement], - s"$selfName.name = $selfName.Block(${libClass.getName}())" - ) - val containingStmtList = PsiTreeUtil.getParentOfType(insertAfter, classOf[PyStatementList]) - - val newAssign = - containingStmtList.addAfter(newAssignTemplate, insertAfter).asInstanceOf[PyAssignmentStatement] - val newArgList = newAssign.getAssignedValue - .asInstanceOf[PyCallExpression] - .getArgument(0, classOf[PyCallExpression]) - .getArgumentList - - val initParams = - DesignAnalysisUtils.initParamsOf(libClass, project).toOption.getOrElse((Seq(), Seq())) - val allParams = initParams._1 ++ initParams._2 - - val nameTemplateVar = new InsertionLiveTemplate.Reference( - "name", - newAssign.getTargets.head.asInstanceOf[PyTargetExpression], - InsertionLiveTemplate.validatePythonName(_, _, Some(containingPsiClass)), - defaultValue = Some("") - ) - - val argTemplateVars = allParams.map { initParam => - val paramName = initParam.getName() + (Option(initParam.getAnnotationValue) match { - case Some(typed) => f": $typed" - case None => "" - }) - - if (initParam.getDefaultValue == null) { // required argument, needs ellipsis - newArgList.addArgument(psiElementGenerator.createEllipsis()) - new InsertionLiveTemplate.Variable(paramName, newArgList.getArguments.last) - } else { // optional argument - // ellipsis is generated in the AST to give the thing a handle, the template replaces it with an empty - newArgList.addArgument( - psiElementGenerator.createKeywordArgument(languageLevel, initParam.getName, "...") - ) - new InsertionLiveTemplate.Variable( - f"$paramName (optional)", - newArgList.getArguments.last.asInstanceOf[PyKeywordArgument].getValueExpression, - defaultValue = Some("") - ) - } - } - - new InsertionLiveTemplate(newAssign, IndexedSeq(nameTemplateVar) ++ argTemplateVars) - } - } - - movableLiveTemplate.addTemplateStateListener(new TemplateFinishedListener { - override def templateFinished(state: TemplateState, brokenOff: Boolean): Unit = { - val expr = state.getExpressionContextForSegment(0) - if (expr.getTemplateEndOffset <= expr.getTemplateStartOffset) { - return // ignored if template was deleted, including through moving the template - } - - val insertedName = state.getVariableValue("name").getText - if (insertedName.isEmpty && brokenOff) { // canceled by esc while name is empty - writeCommandAction(project) - .withName(s"cancel $actionName") - .compute(() => { - TemplateUtils.deleteTemplate(state) - }) - } else { - var templateElem = state.getExpressionContextForSegment(0).getPsiElementAtStartOffset - while (templateElem.isInstanceOf[PsiWhiteSpace]) { // ignore inserted whitespace before the statement - templateElem = templateElem.getNextSibling - } - try { - val args = templateElem - .asInstanceOf[PyAssignmentStatement] - .getAssignedValue - .asInstanceOf[PyCallExpression] // the self.Block(...) call - .getArgument(0, classOf[PyCallExpression]) // the object instantiation - .getArgumentList // args to the object instantiation - val deleteArgs = args.getArguments.flatMap { // remove empty kwargs - case arg: PyKeywordArgument => if (arg.getValueExpression == null) Some(arg) else None - case _ => None // ignored - } - writeCommandAction(project) - .withName(s"clean $actionName") - .compute(() => { - deleteArgs.foreach(_.delete()) - }) - } finally { - continuation(insertedName, templateElem) - } - } - } - }) - - val caretElt = InsertAction.getCaretForNewClassStatement(contextClass, project).toOption - def insertBlockFlow: Unit = { - writeCommandAction(project) - .withName(s"$actionName") - .compute(() => { - movableLiveTemplate.run(caretElt) - }) - } - () => insertBlockFlow - } } diff --git a/src/main/scala/edg_ide/psi_edits/InsertionLiveTemplate.scala b/src/main/scala/edg_ide/psi_edits/InsertionLiveTemplate.scala index 356909ed..b9792bbd 100644 --- a/src/main/scala/edg_ide/psi_edits/InsertionLiveTemplate.scala +++ b/src/main/scala/edg_ide/psi_edits/InsertionLiveTemplate.scala @@ -247,9 +247,10 @@ class InsertionLiveTemplate(elt: PsiElement, variables: IndexedSeq[InsertionLive // if the editor just started, it isn't marked as showing and the tooltip creation crashes // TODO the positioning is still off, but at least it doesn't crash UIUtil.markAsShowing(editor.getContentComponent, true) + val firstVariableName = variables.headOption.map(_.name).getOrElse("") val tooltipString = initialTooltip match { - case Some(initialTooltip) => f"${variables.head.name} | $initialTooltip" - case None => f"${variables.head.name}" + case Some(initialTooltip) => f"$firstVariableName | $initialTooltip" + case None => firstVariableName } val tooltip = createTemplateTooltip(tooltipString, editor) diff --git a/src/main/scala/edg_ide/psi_edits/LiveTemplateConnect.scala b/src/main/scala/edg_ide/psi_edits/LiveTemplateConnect.scala new file mode 100644 index 00000000..83b66b3c --- /dev/null +++ b/src/main/scala/edg_ide/psi_edits/LiveTemplateConnect.scala @@ -0,0 +1,220 @@ +package edg_ide.psi_edits + +import com.intellij.codeInsight.template.impl.TemplateState +import com.intellij.openapi.command.WriteCommandAction.writeCommandAction +import com.intellij.openapi.project.Project +import com.intellij.psi.util.PsiTreeUtil +import com.intellij.psi.{PsiElement, PsiWhiteSpace} +import com.jetbrains.python.psi._ +import edg.util.Errorable +import edg_ide.util.ExceptionNotifyImplicits.{ExceptOption, ExceptSeq} +import edg_ide.util.{DesignAnalysisUtils, PortConnects, exceptable} + +object LiveTemplateConnect { + protected val kTemplateVariableName = "name (optional)" + + // generate the reference Python HDL code for a PortConnect + protected def connectedToRefCode(selfName: String, connect: PortConnects.Base): String = connect match { + case PortConnects.BlockPort(blockName, portName) => s"$selfName.$blockName.$portName" + case PortConnects.BoundaryPort(portName, _) => s"$selfName.$portName" + case PortConnects.BlockVectorUnit(blockName, portName) => s"$selfName.$blockName.$portName" + case PortConnects.BlockVectorSlicePort(blockName, portName, suggestedIndex) => suggestedIndex match { + case Some(suggestedIndex) => s"$selfName.$blockName.$portName.request('$suggestedIndex')" + case None => s"$selfName.$blockName.$portName.request()" + } + case PortConnects.BlockVectorSliceVector(blockName, portName, suggestedIndex) => suggestedIndex match { + case Some(suggestedIndex) => s"$selfName.$blockName.$portName.request_vector('$suggestedIndex')" + case None => s"$selfName.$blockName.$portName.request_vector()" + } + case PortConnects.BlockVectorSlice(blockName, portName, suggestedIndex) => + throw new IllegalArgumentException() + case PortConnects.BoundaryPortVectorUnit(portName) => s"$selfName.$portName" + } + + // gets the class member variable (if any) for a PortConnect + protected def connectedToRequiredAttr(connect: PortConnects.Base): Option[String] = connect match { + // for blocks that are part of an ElementDict (by heuristic), only take the non-index part + case PortConnects.BlockPort(blockName, portName) => Some(blockName.takeWhile(_ != '[')) + case PortConnects.BoundaryPort(portName, _) => Some(portName) + case PortConnects.BlockVectorUnit(blockName, portName) => Some(blockName.takeWhile(_ != '[')) + case PortConnects.BlockVectorSlicePort(blockName, portName, _) => Some(blockName.takeWhile(_ != '[')) + case PortConnects.BlockVectorSliceVector(blockName, portName, _) => Some(blockName.takeWhile(_ != '[')) + case PortConnects.BlockVectorSlice(blockName, portName, _) => Some(blockName.takeWhile(_ != '[')) + case PortConnects.BoundaryPortVectorUnit(portName) => Some(portName) + } + + // Creates an action to start a live template to insert the connection + def createTemplateConnect( + contextClass: PyClass, + actionName: String, + project: Project, + startingConnect: PortConnects.Base, + newConnects: Seq[PortConnects.Base], + continuation: (Option[String], PsiElement) => Unit, + ): Errorable[() => Unit] = exceptable { + val languageLevel = LanguageLevel.forElement(contextClass) + val psiElementGenerator = PyElementGenerator.getInstance(project) + val allConnects = startingConnect +: newConnects + + val movableLiveTemplate = new MovableLiveTemplate(actionName) { + // starts the live template insertion as a statement + protected def startStatementInsertionTemplate( + insertAfter: PsiElement + ): InsertionLiveTemplate = { + val containingPsiFn = PsiTreeUtil.getParentOfType(insertAfter, classOf[PyFunction]) + val selfName = containingPsiFn.getParameterList.getParameters.toSeq + .headOption.exceptNone(s"function ${containingPsiFn.getName} has no self") + .getName + val containingStmtList = PsiTreeUtil.getParentOfType(insertAfter, classOf[PyStatementList]) + + // TODO don't allow name when appending to a named connect (?) + val newConnectTemplate = psiElementGenerator.createFromText( + languageLevel, + classOf[PyAssignmentStatement], + s"$selfName.name = $selfName.connect()" + ) + val newConnect = + containingStmtList.addAfter(newConnectTemplate, insertAfter).asInstanceOf[PyAssignmentStatement] + val newConnectArgs = newConnect.getAssignedValue + .asInstanceOf[PyCallExpression] + .getArgumentList + + allConnects.foreach { newConnect => + newConnectArgs.addArgument(psiElementGenerator.createExpressionFromText( + languageLevel, + connectedToRefCode(selfName, newConnect) + )) + } + + val nameTemplateVar = new InsertionLiveTemplate.Reference( + kTemplateVariableName, + newConnect.getTargets.head.asInstanceOf[PyTargetExpression], + defaultValue = Some("") + ) + + new InsertionLiveTemplate(newConnect, IndexedSeq(nameTemplateVar)) + } + + // if caretElt is in a PyCallExpression that is a connect involving any of the ports in this connection, + // return the InsertionLiveTemplate, otherwise None + protected def tryStartStatementInsertionTemplate( + caretElt: PsiElement, + earliestPosition: Option[PsiElement] + ): Option[InsertionLiveTemplate] = { + val callCandidate = PsiTreeUtil.getParentOfType(caretElt, classOf[PyCallExpression]) + if (callCandidate == null) return None + + val containingPsiFn = PsiTreeUtil.getParentOfType(caretElt, classOf[PyFunction]) + val selfName = containingPsiFn.getParameterList.getParameters.toSeq + .headOption.getOrElse(return None) + .getName + val connectReference = psiElementGenerator.createExpressionFromText(languageLevel, s"$selfName.connect") + if (!callCandidate.getCallee.textMatches(connectReference)) return None + + val isAfterEarliest = earliestPosition.flatMap( // aka is in valid position + DesignAnalysisUtils.elementAfterEdg(_, callCandidate, project)).getOrElse(true) + if (!isAfterEarliest) return None + + val matchingConnects = callCandidate.getArgumentList.getArguments.flatMap(arg => + allConnects.flatMap { connect => + if (arg.textMatches(connectedToRefCode(selfName, connect))) { + Some(connect) + } else { + None + } + } + ) + if (matchingConnects.isEmpty) return None + + // validation complete, start the template + val connectsToAdd = allConnects.filter(!matchingConnects.contains(_)) + connectsToAdd.foreach { newConnect => + callCandidate.getArgumentList.addArgument(psiElementGenerator.createExpressionFromText( + languageLevel, + connectedToRefCode(selfName, newConnect) + )) + } + + Some(new InsertionLiveTemplate(callCandidate, IndexedSeq())) + } + + // TODO startTemplate should be able to fail - Errorable + override def startTemplate(caretEltOpt: Option[PsiElement]): InsertionLiveTemplate = { + // find earliest insertion position (after all refs are defined) + val allRequiredAttrs = allConnects.flatMap(connectedToRequiredAttr) + val earliestPosition = TemplateUtils.getLastAttributeAssignment(contextClass, allRequiredAttrs, project) + + // TODO live template insertion needs to support modifying AST nodes (instead of only addition) +// caretEltOpt.foreach { caretElt => // check if caret is in a connect +// tryStartStatementInsertionTemplate(caretElt, earliestPosition).foreach { +// return _ +// } +// } // otherwise continue to stmt insertion + + val validCaretEltOpt = caretEltOpt.flatMap(TemplateUtils.getInsertionStmt(_, contextClass)) + val preInsertAfter = validCaretEltOpt + .getOrElse(InsertAction.findInsertionElements(contextClass, InsertBlockAction.VALID_FUNCTION_NAMES).head) + + // adjust insertion position to be after all assignments to required references + val insertAfter = earliestPosition.map { earliestPosition => + if (!DesignAnalysisUtils.elementAfterEdg(preInsertAfter, earliestPosition, project).getOrElse(true)) { + preInsertAfter + } else { + earliestPosition + } + }.getOrElse(preInsertAfter) + startStatementInsertionTemplate(insertAfter) + } + } + + movableLiveTemplate.addTemplateStateListener(new TemplateFinishedListener { + override def templateFinished(state: TemplateState, brokenOff: Boolean): Unit = { + val expr = state.getExpressionContextForSegment(0) + if (expr.getTemplateEndOffset <= expr.getTemplateStartOffset) { + return // ignored if template was deleted, including through moving the template + } + + val nameVar = Option(state.getVariableValue(kTemplateVariableName)) + val insertedNameOpt = nameVar.map(_.getText).filter(_.nonEmpty) + if (insertedNameOpt.isEmpty && brokenOff) { // canceled by esc while name is empty + writeCommandAction(project) + .withName(s"cancel $actionName") + .compute(() => { + TemplateUtils.deleteTemplate(state) + }) + } else { // commit + var templateElem = state.getExpressionContextForSegment(0).getPsiElementAtStartOffset + while (templateElem.isInstanceOf[PsiWhiteSpace]) { // ignore inserted whitespace before the statement + templateElem = templateElem.getNextSibling + } + + if (nameVar.nonEmpty && insertedNameOpt.isEmpty) { // if name not specified, make the connect anonymous + try { + val templateAssign = templateElem.asInstanceOf[PyAssignmentStatement] + writeCommandAction(project) + .withName(s"clean $actionName") + .compute(() => { + templateElem.replace(templateAssign.getAssignedValue) + }) + } catch { + case _: Throwable => // ignore + } + } + + continuation(insertedNameOpt, templateElem) + } + } + }) + + val caretElt = InsertAction.getCaretForNewClassStatement(contextClass, project).toOption + def insertBlockFlow: Unit = { + writeCommandAction(project) + .withName(s"$actionName") + .compute(() => { + movableLiveTemplate.run(caretElt) + }) + } + + () => insertBlockFlow + } +} diff --git a/src/main/scala/edg_ide/psi_edits/LiveTemplateInsertBlock.scala b/src/main/scala/edg_ide/psi_edits/LiveTemplateInsertBlock.scala new file mode 100644 index 00000000..c0cc5f30 --- /dev/null +++ b/src/main/scala/edg_ide/psi_edits/LiveTemplateInsertBlock.scala @@ -0,0 +1,145 @@ +package edg_ide.psi_edits + +import com.intellij.codeInsight.template.impl.TemplateState +import com.intellij.openapi.command.WriteCommandAction.writeCommandAction +import com.intellij.openapi.project.Project +import com.intellij.psi.util.PsiTreeUtil +import com.intellij.psi.{PsiElement, PsiWhiteSpace} +import com.jetbrains.python.psi._ +import edg.util.Errorable +import edg_ide.util.ExceptionNotifyImplicits.ExceptSeq +import edg_ide.util.{DesignAnalysisUtils, exceptable} + +object LiveTemplateInsertBlock { + + /** Creates an action to start a live template to insert a block. + */ + def createTemplateBlock( + contextClass: PyClass, + libClass: PyClass, + actionName: String, + project: Project, + continuation: (String, PsiElement) => Unit + ): Errorable[() => Unit] = exceptable { + val languageLevel = LanguageLevel.forElement(libClass) + val psiElementGenerator = PyElementGenerator.getInstance(project) + + val movableLiveTemplate = new MovableLiveTemplate(actionName) { + // TODO startTemplate should be able to fail - Errorable + override def startTemplate(caretEltOpt: Option[PsiElement]): InsertionLiveTemplate = { + val insertAfter = caretEltOpt.flatMap(TemplateUtils.getInsertionStmt(_, contextClass)) + .getOrElse(InsertAction.findInsertionElements(contextClass, InsertBlockAction.VALID_FUNCTION_NAMES).head) + val containingPsiFn = PsiTreeUtil.getParentOfType(insertAfter, classOf[PyFunction]) + val containingPsiClass = PsiTreeUtil.getParentOfType(containingPsiFn, classOf[PyClass]) + val selfName = containingPsiFn.getParameterList.getParameters.toSeq + .exceptEmpty(s"function ${containingPsiFn.getName} has no self") + .head + .getName + + val newAssignTemplate = psiElementGenerator.createFromText( + languageLevel, + classOf[PyAssignmentStatement], + s"$selfName.name = $selfName.Block(${libClass.getName}())" + ) + val containingStmtList = PsiTreeUtil.getParentOfType(insertAfter, classOf[PyStatementList]) + + val newAssign = + containingStmtList.addAfter(newAssignTemplate, insertAfter).asInstanceOf[PyAssignmentStatement] + val newArgList = newAssign.getAssignedValue + .asInstanceOf[PyCallExpression] + .getArgument(0, classOf[PyCallExpression]) + .getArgumentList + + val initParams = + DesignAnalysisUtils.initParamsOf(libClass, project).toOption.getOrElse((Seq(), Seq())) + val allParams = initParams._1 ++ initParams._2 + + val nameTemplateVar = new InsertionLiveTemplate.Reference( + "name", + newAssign.getTargets.head.asInstanceOf[PyTargetExpression], + InsertionLiveTemplate.validatePythonName(_, _, Some(containingPsiClass)), + defaultValue = Some("") + ) + + val argTemplateVars = allParams.map { initParam => + val paramName = initParam.getName() + (Option(initParam.getAnnotationValue) match { + case Some(typed) => f": $typed" + case None => "" + }) + + if (initParam.getDefaultValue == null) { // required argument, needs ellipsis + newArgList.addArgument(psiElementGenerator.createEllipsis()) + new InsertionLiveTemplate.Variable(paramName, newArgList.getArguments.last) + } else { // optional argument + // ellipsis is generated in the AST to give the thing a handle, the template replaces it with an empty + newArgList.addArgument( + psiElementGenerator.createKeywordArgument(languageLevel, initParam.getName, "...") + ) + new InsertionLiveTemplate.Variable( + f"$paramName (optional)", + newArgList.getArguments.last.asInstanceOf[PyKeywordArgument].getValueExpression, + defaultValue = Some("") + ) + } + } + + new InsertionLiveTemplate(newAssign, IndexedSeq(nameTemplateVar) ++ argTemplateVars) + } + } + + movableLiveTemplate.addTemplateStateListener(new TemplateFinishedListener { + override def templateFinished(state: TemplateState, brokenOff: Boolean): Unit = { + val expr = state.getExpressionContextForSegment(0) + if (expr.getTemplateEndOffset <= expr.getTemplateStartOffset) { + return // ignored if template was deleted, including through moving the template + } + + val insertedName = state.getVariableValue("name").getText + if (insertedName.isEmpty && brokenOff) { // canceled by esc while name is empty + writeCommandAction(project) + .withName(s"cancel $actionName") + .compute(() => { + TemplateUtils.deleteTemplate(state) + }) + } else { + var templateElem = state.getExpressionContextForSegment(0).getPsiElementAtStartOffset + while (templateElem.isInstanceOf[PsiWhiteSpace]) { // ignore inserted whitespace before the statement + templateElem = templateElem.getNextSibling + } + try { + val args = templateElem + .asInstanceOf[PyAssignmentStatement] + .getAssignedValue + .asInstanceOf[PyCallExpression] // the self.Block(...) call + .getArgument(0, classOf[PyCallExpression]) // the object instantiation + .getArgumentList // args to the object instantiation + val deleteArgs = args.getArguments.flatMap { // remove empty kwargs + case arg: PyKeywordArgument => if (arg.getValueExpression == null) Some(arg) else None + case _ => None // ignored + } + writeCommandAction(project) + .withName(s"clean $actionName") + .compute(() => { + deleteArgs.foreach(_.delete()) + }) + } catch { + case _: Throwable => // ignore + } + continuation(insertedName, templateElem) + } + } + }) + + val caretElt = InsertAction.getCaretForNewClassStatement(contextClass, project).toOption + + def insertBlockFlow: Unit = { + writeCommandAction(project) + .withName(s"$actionName") + .compute(() => { + movableLiveTemplate.run(caretElt) + }) + } + + () => insertBlockFlow + } +} diff --git a/src/main/scala/edg_ide/psi_edits/TemplateUtils.scala b/src/main/scala/edg_ide/psi_edits/TemplateUtils.scala index 63d08c87..8e557491 100644 --- a/src/main/scala/edg_ide/psi_edits/TemplateUtils.scala +++ b/src/main/scala/edg_ide/psi_edits/TemplateUtils.scala @@ -2,7 +2,12 @@ package edg_ide.psi_edits import com.intellij.codeInsight.template.{Template, TemplateEditingAdapter} import com.intellij.codeInsight.template.impl.TemplateState +import com.intellij.openapi.project.Project import com.intellij.psi.PsiElement +import com.intellij.psi.util.PsiTreeUtil +import com.jetbrains.python.psi.{PyClass, PyFunction, PyStatement} +import edg_ide.util.ExceptionNotifyImplicits.ExceptNotify +import edg_ide.util.{DesignAnalysisUtils, requireExcept} import scala.collection.mutable @@ -35,6 +40,28 @@ object TemplateUtils { } templateState.update() // update to end the template } + + // given some caret position, returns the top-level statement if it's valid for statement insertion + def getInsertionStmt(caretElt: PsiElement, requiredContextClass: PyClass): Option[PyStatement] = { + val caretStatement = InsertAction.snapInsertionEltOfType[PyStatement](caretElt).get + val containingPsiFn = PsiTreeUtil + .getParentOfType(caretStatement, classOf[PyFunction]) + if (containingPsiFn == null) return None + val containingPsiClass = PsiTreeUtil + .getParentOfType(containingPsiFn, classOf[PyClass]) + if (containingPsiClass == null) return None + if (containingPsiClass != requiredContextClass) return None + Some(caretStatement) + } + + // given a class and a list of attributes, returns the last attribute assignment, if any + def getLastAttributeAssignment(psiClass: PyClass, attrs: Seq[String], project: Project): Option[PyStatement] = { + attrs.flatMap { attr => + DesignAnalysisUtils.findAssignmentsTo(psiClass, attr, project) + }.sortWith { case (a, b) => + DesignAnalysisUtils.elementAfterEdg(a, b, project).getOrElse(false) + }.lastOption + } } /** Utility on top of TemplateEditingAdapter that provides a templateFinished that provides both the TemplateState (note diff --git a/src/main/scala/edg_ide/swing/blocks/JBlockDiagramVisualizer.scala b/src/main/scala/edg_ide/swing/blocks/JBlockDiagramVisualizer.scala index 50755987..ff744131 100644 --- a/src/main/scala/edg_ide/swing/blocks/JBlockDiagramVisualizer.scala +++ b/src/main/scala/edg_ide/swing/blocks/JBlockDiagramVisualizer.scala @@ -195,22 +195,23 @@ class JBlockDiagramVisualizer(var rootNode: ElkNode, var showTop: Boolean = fals val elementGraphics = mouseOverElts.map { elt => elt -> mouseoverModifier } ++ errorElts.map { elt => elt -> errorModifier } ++ selected.map { elt => elt -> selectedModifier } ++ - unselectable.map { elt => elt -> dimGraphics } ++ staleElts.map { elt => elt -> staleModifier } + // unselectable handled separately - when highlighting is active, it is treated as dimmed with the rest val backgroundPaintGraphics = paintGraphics.create().asInstanceOf[Graphics2D] backgroundPaintGraphics.setBackground(this.getBackground) val painter = highlighted match { case None => // normal rendering - new StubEdgeElkNodePainter(rootNode, showTop, zoomLevel, elementGraphics = elementGraphics) + val innerElementGraphics = elementGraphics ++ + unselectable.map { elt => elt -> dimGraphics } + new StubEdgeElkNodePainter(rootNode, showTop, zoomLevel, elementGraphics = innerElementGraphics) case Some(highlighted) => // default dim rendering - val highlightedElementGraphics = elementGraphics ++ - highlighted.map { elt => - elt -> ElementGraphicsModifier( // undo the dim rendering for highlighted - strokeGraphics = ElementGraphicsModifier.withColor(getForeground), - textGraphics = ElementGraphicsModifier.withColor(getForeground) - ) - } + val highlightedElementGraphics = (highlighted -- unselectable).toSeq.map { elt => + elt -> ElementGraphicsModifier( // undo the dim rendering for highlighted + strokeGraphics = ElementGraphicsModifier.withColor(getForeground), + textGraphics = ElementGraphicsModifier.withColor(getForeground) + ) + } ++ elementGraphics new StubEdgeElkNodePainter( rootNode, showTop, diff --git a/src/main/scala/edg_ide/ui/BlockVisualizerPanel.scala b/src/main/scala/edg_ide/ui/BlockVisualizerPanel.scala index c6aab056..7f441e03 100644 --- a/src/main/scala/edg_ide/ui/BlockVisualizerPanel.scala +++ b/src/main/scala/edg_ide/ui/BlockVisualizerPanel.scala @@ -27,7 +27,7 @@ import edgrpc.hdl.{hdl => edgrpc} import org.eclipse.elk.graph.{ElkGraphElement, ElkNode} import java.awt.datatransfer.DataFlavor -import java.awt.event.{ComponentAdapter, ComponentEvent, MouseAdapter, MouseEvent} +import java.awt.event.{ComponentAdapter, ComponentEvent, KeyAdapter, KeyEvent, MouseAdapter, MouseEvent} import java.awt.{BorderLayout, GridBagConstraints, GridBagLayout} import java.io.{File, FileInputStream} import java.util.concurrent.{Callable, TimeUnit} @@ -215,6 +215,12 @@ class BlockVisualizerPanel(val project: Project, toolWindow: ToolWindow) extends } } } + graph.addKeyListener(new KeyAdapter { + override def keyPressed(e: KeyEvent): Unit = { + activeTool.onKeyPress(e) + } + }) + private val centeringGraph = new JPanel(new GridBagLayout) centeringGraph.add(graph, new GridBagConstraints()) @@ -412,6 +418,9 @@ class BlockVisualizerPanel(val project: Project, toolWindow: ToolWindow) extends if (activeTool != defaultTool) { // revert to the default tool toolInterface.endTool() // TODO should we also preserve state like selected? } + + stalePaths.clear() + updateStale() } /** Updates the design tree only, where the overall "top design" does not change. Mainly used for speculative updates diff --git a/src/main/scala/edg_ide/ui/LibraryPanel.scala b/src/main/scala/edg_ide/ui/LibraryPanel.scala index a5a28682..667ba8c0 100644 --- a/src/main/scala/edg_ide/ui/LibraryPanel.scala +++ b/src/main/scala/edg_ide/ui/LibraryPanel.scala @@ -112,7 +112,7 @@ class LibraryBlockPopupMenu(blockType: ref.LibraryPath, project: Project) extend val (insertAction, insertItem) = if (EdgSettingsState.getInstance().useInsertionLiveTemplates) { val insertAction: Errorable[() => Unit] = exceptable { - InsertBlockAction + LiveTemplateInsertBlock .createTemplateBlock( contextPyClass.exceptError, blockPyClass.exceptError, diff --git a/src/main/scala/edg_ide/ui/tools/BaseTool.scala b/src/main/scala/edg_ide/ui/tools/BaseTool.scala index 0d8a6cd9..7cfedf08 100644 --- a/src/main/scala/edg_ide/ui/tools/BaseTool.scala +++ b/src/main/scala/edg_ide/ui/tools/BaseTool.scala @@ -4,7 +4,7 @@ import com.intellij.openapi.project.Project import edg.wir.{DesignPath, Library} import edgir.schema.schema -import java.awt.event.MouseEvent +import java.awt.event.{KeyEvent, MouseEvent} trait ToolInterface { // Returns the top-level visualization (focus / context) path @@ -47,4 +47,7 @@ trait BaseTool { // Mouse event that is generated on any mouse event in either the design tree or graph layout def onPathMouse(e: MouseEvent, path: DesignPath): Unit = {} + + // Keyboard event generated on any key even in the graph layout + def onKeyPress(e: KeyEvent): Unit = {} } diff --git a/src/main/scala/edg_ide/ui/tools/DefaultTool.scala b/src/main/scala/edg_ide/ui/tools/DefaultTool.scala index d2106229..dec6cdde 100644 --- a/src/main/scala/edg_ide/ui/tools/DefaultTool.scala +++ b/src/main/scala/edg_ide/ui/tools/DefaultTool.scala @@ -9,7 +9,7 @@ import edg.util.Errorable import edg.wir.DesignPath import edg.wir.ProtoUtil.ParamProtoToSeqMap import edg_ide.dse._ -import edg_ide.ui.{BlockVisualizerService, ContextMenuUtils, DseService, PopupUtils} +import edg_ide.ui.{BlockVisualizerService, ContextMenuUtils, DseService, EdgSettingsState, PopupUtils} import edg_ide.util.ExceptionNotifyImplicits.ExceptErrorable import edg_ide.util._ import edg_ide.{EdgirUtils, PsiUtils} @@ -278,7 +278,11 @@ class DesignPortPopupMenu(path: DesignPath, interface: ToolInterface) addSeparator() val startConnectAction = exceptable { - val connectTool = ConnectTool(interface, path).exceptError + val connectTool = if (EdgSettingsState.getInstance().useInsertionLiveTemplates) { + NewConnectTool(interface, path).exceptError + } else { + ConnectTool(interface, path).exceptError + } () => interface.startNewTool(connectTool) } private val startConnectItem = ContextMenuUtils.MenuItemFromErrorable(startConnectAction, "Start Connect") diff --git a/src/main/scala/edg_ide/ui/tools/NewConnectTool.scala b/src/main/scala/edg_ide/ui/tools/NewConnectTool.scala new file mode 100644 index 00000000..a17ff09f --- /dev/null +++ b/src/main/scala/edg_ide/ui/tools/NewConnectTool.scala @@ -0,0 +1,269 @@ +package edg_ide.ui.tools + +import com.intellij.openapi.diagnostic.Logger +import com.intellij.psi.PsiElement +import edg.util.Errorable +import edg.wir.{DesignPath, LibraryConnectivityAnalysis} +import edg_ide.EdgirUtils +import edg_ide.psi_edits.LiveTemplateConnect +import edg_ide.ui.{BlockVisualizerService, PopupUtils} +import edg_ide.util.ExceptionNotifyImplicits.{ExceptErrorable, ExceptNotify, ExceptOption, ExceptSeq} +import edg_ide.util.{ + BlockConnectedAnalysis, + ConnectBuilder, + DesignAnalysisUtils, + EdgirConnectExecutor, + PortConnectTyped, + PortConnects, + exceptable, + requireExcept +} +import edgir.elem.elem + +import java.awt.Component +import java.awt.event.{KeyEvent, MouseEvent} +import javax.swing.SwingUtilities +import scala.collection.mutable + +object NewConnectTool { + def apply(interface: ToolInterface, portPath: DesignPath): Errorable[NewConnectTool] = exceptable { + val focusPath = interface.getFocus + val focusBlock = EdgirUtils + .resolveExact(focusPath, interface.getDesign) + .exceptNone("can't reach focus block") + .instanceOfExcept[elem.HierarchyBlock]("focus block not a block") + + val portLink = { + val port = EdgirUtils.resolveExact(portPath, interface.getDesign).exceptNone("no port") + val portType = port match { + case port: elem.Port => port.getSelfClass + case port: elem.Bundle => port.getSelfClass + case array: elem.PortArray => array.getSelfClass + case _ => exceptable.fail("invalid port type") + } + val libraryAnalysis = new LibraryConnectivityAnalysis(interface.getLibrary) // TODO save and reuse? + val linkType = libraryAnalysis.linkOfPort(portType).exceptNone("no link type for port") + interface.getLibrary.getLink(linkType).exceptError + } + + val portRef = { // get selected port as Seq(...) reference + val (containingBlockPath, containingBlock) = EdgirUtils.resolveDeepestBlock(portPath, interface.getDesign) + val portRef = portPath.postfixFromOption(containingBlockPath).exceptNone("port not in focus block") + val portName = portRef.steps.headOption.exceptNone("port path empty").getName + + if (containingBlockPath == focusPath) { // boundary port + Seq(portName) + } else { // block port + val (blockParent, blockName) = containingBlockPath.split + requireExcept(blockParent == focusPath, "port not in focus block") + Seq(blockName, portName) + } + } + + val analysis = new BlockConnectedAnalysis(focusBlock) + val (portConnectName, portConnecteds, portConstrs) = { + // TODO findLast is a hack to prioritize the new vector slice (as opposed to prior connected ones) + analysis.connectedGroups.findLast { case (linkNameOpt, connecteds, constrs) => + connecteds.exists(_.connect.topPortRef == portRef) + }.exceptNone("no connection") + } + val portConnected = portConnecteds.filter(_.connect.topPortRef == portRef) + .onlyExcept("multiple connections") + var connectBuilder = ConnectBuilder(focusBlock, portLink, portConstrs) + .exceptNone("invalid connections to port") + + if (portConstrs.isEmpty) { // if no constraints (no prior link) start with the port itself + connectBuilder = connectBuilder.append(Seq(portConnected)) + .exceptNone("invalid connections to port") + } + + new NewConnectTool(interface, portConnectName, focusPath, portConnected, connectBuilder, analysis) + } +} + +class NewConnectTool( + val interface: ToolInterface, + linkNameOpt: Option[String], + containingBlockPath: DesignPath, + startingPort: PortConnectTyped[PortConnects.Base], + baseConnectBuilder: ConnectBuilder, // including startingPort, even if it's the only item (new link) + analysis: BlockConnectedAnalysis +) extends BaseTool { + private val logger = Logger.getInstance(this.getClass) + + val startingPortPath = containingBlockPath ++ startingPort.connect.topPortRef + + // individual ports selected by the user + var selectedConnects = mutable.ArrayBuffer[PortConnectTyped[PortConnects.Base]]() + // corresponding to selectedPorts, may have more ports from net joins + var currentConnectBuilder = baseConnectBuilder + + def getCurrentName(): String = { + if (selectedConnects.nonEmpty) { + val connectedPortNames = selectedConnects.map(_.connect.topPortRef.mkString(".")) + s"Connect ${connectedPortNames.mkString(", ")} to ${startingPort.connect.topPortRef.mkString(".")}" + } else { + s"Connect to ${startingPort.connect.topPortRef.mkString(".")}" + } + } + + def updateSelected(): Unit = { // updates selected in graph and text + if (selectedConnects.isEmpty) { + interface.setStatus("[Esc] cancel;" + getCurrentName()) + } else { + interface.setStatus("[Esc] cancel; [Enter/DblClick] complete; " + getCurrentName()) + } + + // mark all current selections + val connectedPorts = currentConnectBuilder.connected.map(containingBlockPath ++ _._1.connect.topPortRef) + interface.setGraphSelections(connectedPorts.toSet) + + // try all connections to determine additional possible connects + // note, vector slices may overlap and appear in multiple connect groups (and a new connection), + // but in this case it doesn't matter since this takes the most available port for marking highlights + val connectablePorts = mutable.ArrayBuffer[DesignPath]() + val connectableBlocks = mutable.ArrayBuffer[DesignPath]() + analysis.connectedGroups.foreach { case (linkNameOpt, connecteds, constrs) => + currentConnectBuilder.append(connecteds) match { + case Some(_) => connecteds.foreach { connected => + connected.connect.topPortRef match { // add containing block, if a block port + case Seq(blockName, portName) => connectableBlocks.append(containingBlockPath + blockName) + case _ => // ignored + } + connectablePorts.append(containingBlockPath ++ connected.connect.topPortRef) + } + case None => Seq() + } + } + + // enable selection of existing ports in connection (toggle-able) and connect-able ports + interface.setGraphHighlights( + Some((Seq(containingBlockPath) ++ connectedPorts ++ connectablePorts ++ connectableBlocks).toSet) + ) + } + + override def init(): Unit = { + updateSelected() + } + + def removeConnect(portPath: DesignPath): Unit = { + // remove all connections to the port path (should really only be one) + selectedConnects.filterInPlace(containingBlockPath ++ _.connect.topPortRef != portPath) + val selectedConnectConnects = selectedConnects.map(_.connect) + // recompute from scratch on removal, for simplicity + val allConnected = analysis.connectedGroups.filter { case (linkNameOpt, connecteds, constrs) => + connecteds.exists(connected => + selectedConnectConnects.contains(connected.connect) && + !((connected.connect.isInstanceOf[PortConnects.BlockVectorSlicePort] || + connected.connect.isInstanceOf[PortConnects.BlockVectorSliceVector]) && + connecteds.length > 1) // only keep individual BlockVectorSlice, TODO should not be hardcoded + ) + }.flatMap(_._2) + // update state + baseConnectBuilder.append(allConnected) match { + case Some(newConnectBuilder) => + currentConnectBuilder = newConnectBuilder + case None => // if the connect is invalid (shouldn't be possible), revert to the empty connect + logger.error(s"invalid connect from removal of $portPath") + currentConnectBuilder = baseConnectBuilder + selectedConnects.clear() + } + updateSelected() + } + + def addConnect(portPath: DesignPath): Unit = { + val newConnectedNet = analysis.connectedGroups.filter { case (linkNameOpt, connecteds, constrs) => + connecteds.exists(connected => + containingBlockPath ++ connected.connect.topPortRef == portPath && + !((connected.connect.isInstanceOf[PortConnects.BlockVectorSlicePort] || + connected.connect.isInstanceOf[PortConnects.BlockVectorSliceVector]) && + connecteds.length > 1) // only keep individual BlockVectorSlice, TODO should not be hardcoded + ) + }.flatMap(_._2) + val newConnected = // get single connected of this port + newConnectedNet.filter(containingBlockPath ++ _.connect.topPortRef == portPath) + val newConnectBuilder = currentConnectBuilder.append(newConnectedNet) + (newConnected, newConnectBuilder) match { + case (Seq(newConnected), Some(newConnectBuilder)) => // valid connect, commit and update UI + selectedConnects.append(newConnected) + currentConnectBuilder = newConnectBuilder + updateSelected() + case _ => + logger.warn(s"invalid connect from added port $portPath") // invalid connect, ignore + } + } + + protected def completeConnect(component: Component): Unit = { + if (selectedConnects.nonEmpty) { + val newConnects = selectedConnects.toSeq + val connectedBlockOpt = + EdgirConnectExecutor( + analysis.block, + linkNameOpt, + currentConnectBuilder, + startingPort, + newConnects + ) + val containerPyClassOpt = DesignAnalysisUtils.pyClassOf(analysis.block.getSelfClass, interface.getProject) + + (connectedBlockOpt, containerPyClassOpt.toOption) match { + case (Some(connectedBlock), Some(containerPyClass)) => + val continuation = (name: Option[String], inserted: PsiElement) => { + BlockVisualizerService(interface.getProject).visualizerPanelOption.foreach { + _.currentDesignModifyBlock(containingBlockPath)(_ => connectedBlock) + } + interface.endTool() + } + LiveTemplateConnect.createTemplateConnect( + containerPyClass, + getCurrentName(), + interface.getProject, + startingPort.connect, + newConnects.map(_.connect), + continuation + ).exceptError() + case _ => + if (connectedBlockOpt.isEmpty) { + logger.error(s"failed to create connected IR block") + } + containerPyClassOpt match { + case Errorable.Error(msg) => logger.error(s"failed to get container pyclass: $msg") + case _ => // ignored + } + PopupUtils.createErrorPopupAtMouse(s"internal error", component) + interface.endTool() + } + } else { // nothing to do, cancel + interface.endTool() + } + } + + override def onPathMouse(e: MouseEvent, path: DesignPath): Unit = { + val resolved = EdgirUtils.resolveExact(path, interface.getDesign) + + if (SwingUtilities.isLeftMouseButton(e) && e.getClickCount == 1) { // toggle selected port + val currentSelectedPorts = currentConnectBuilder.connected.map(containingBlockPath ++ _._1.connect.topPortRef) + resolved match { + case Some(_: elem.Port | _: elem.Bundle | _: elem.PortArray) => // toggle port + if (selectedConnects.exists(containingBlockPath ++ _.connect.topPortRef == path)) { // toggle de-select + removeConnect(path) + } else if (!currentSelectedPorts.contains(path) && path != startingPortPath) { // toggle select + addConnect(path) + } // otherwise unselectable port / block + case _ => // ignored + } + } else if (SwingUtilities.isLeftMouseButton(e) && e.getClickCount == 2) { // double-click finish shortcut + completeConnect(e.getComponent) + } + } + + override def onKeyPress(e: KeyEvent): Unit = { + if (e.getKeyCode == KeyEvent.VK_ESCAPE) { + interface.endTool() + e.consume() + } else if (e.getKeyCode == KeyEvent.VK_ENTER) { + completeConnect(e.getComponent) + e.consume() + } + } +} diff --git a/src/main/scala/edg_ide/util/BlockConnectedAnalysis.scala b/src/main/scala/edg_ide/util/BlockConnectedAnalysis.scala new file mode 100644 index 00000000..ea9ebc4e --- /dev/null +++ b/src/main/scala/edg_ide/util/BlockConnectedAnalysis.scala @@ -0,0 +1,123 @@ +package edg_ide.util +import edg.wir.ProtoUtil.{BlockProtoToSeqMap, ConstraintProtoToSeqMap, PortProtoToSeqMap} +import edgir.elem.elem +import edgir.expr.expr + +import scala.collection.mutable + +// provides link-level connectivity information (e.g. all connected ports in a link) for a block +class BlockConnectedAnalysis(val block: elem.HierarchyBlock) { + protected val connectionsBuilder = mutable.ArrayBuffer[( + Option[String], // link name, if part of a link + mutable.ArrayBuffer[PortConnectTyped[PortConnects.ConstraintBase]], + mutable.ArrayBuffer[expr.ValueExpr] + )]() + protected val linkNameToConnectionIndex = mutable.Map[String, Int]() // allows quick indexing + + // here, invalid constraints are silently discarded + block.constraints.toSeqMap.foreach { case (name, constr) => + val linkNameOpt = constr.expr match { // get the link name / builder map key, if it is a valid constraint + case expr.ValueExpr.Expr.Connected(connected) => + Some(connected.getLinkPort.getRef.steps.head.getName) + case expr.ValueExpr.Expr.ConnectedArray(connected) => + Some(connected.getLinkPort.getRef.steps.head.getName) + case expr.ValueExpr.Expr.Exported(exported) => None + case expr.ValueExpr.Expr.ExportedArray(exported) => None + case _ => None // ignored + } + val (connectedBuilder, constrBuilder) = linkNameOpt match { + case Some(linkName) => // is a link, need to fetch the existing link entry or add a new one + linkNameToConnectionIndex.get(linkName) match { + case Some(index) => + (connectionsBuilder(index)._2, connectionsBuilder(index)._3) + case None => + connectionsBuilder.append((Some(linkName), mutable.ArrayBuffer(), mutable.ArrayBuffer())) + linkNameToConnectionIndex(linkName) = connectionsBuilder.length - 1 + (connectionsBuilder.last._2, connectionsBuilder.last._3) + } + case None => // anonymous + connectionsBuilder.append((None, mutable.ArrayBuffer(), mutable.ArrayBuffer())) + (connectionsBuilder.last._2, connectionsBuilder.last._3) + } + + val connectedPortsOpt = PortConnects.fromConnect(constr).map { connecteds => + // silently discard non-found port + connecteds.flatMap(PortConnectTyped.fromConnect(_, block)) + } + connectedPortsOpt.foreach { connectedPorts => + connectedBuilder.addAll(connectedPorts) + } + constrBuilder.append(constr) + } + + protected val connectedsByPortRef = // all PortConnects by port ref, not including others in the connection + mutable.HashMap[Seq[String], mutable.ArrayBuffer[PortConnectTyped[PortConnects.ConstraintBase]]]() + connectionsBuilder.foreach { case (name, connecteds, constrs) => + connecteds.foreach { connected => + connectedsByPortRef.getOrElseUpdate(connected.connect.topPortRef, mutable.ArrayBuffer()).append(connected) + } + } + + protected val disconnectedBoundaryPortConnections = + mutable.ArrayBuffer[PortConnectTyped[PortConnects.ConstraintBase]]() + // TODO boundary ports, exports, and bridging currently not supported +// block.ports.toSeqMap.collect { +// case (portName, port) if !allConnectedPorts.contains(Seq(portName)) => +// val connectTypedOpt = port.is match { +// case elem.PortLike.Is.Port(port) => +// Some(PortConnectTyped(PortConnects.BoundaryPort(portName, Seq()), port.getSelfClass)) +// case elem.PortLike.Is.Bundle(port) => +// Some(PortConnectTyped(PortConnects.BoundaryPort(portName, Seq()), port.getSelfClass)) +// case elem.PortLike.Is.Array(array) => +// Some(PortConnectTyped(PortConnects.BoundaryPortVectorUnit(portName), array.getSelfClass)) +// case _ => None +// } +// connectTypedOpt.foreach { connect => +// disconnectedBoundaryPortConnections.append(connect) +// } +// } + + protected val disconnectedBlockPortConnections = + mutable.ArrayBuffer[PortConnectTyped[PortConnects.ConstraintBase]]() + block.blocks.toSeqMap.foreach { case (subBlockName, subBlock) => + subBlock.`type`.hierarchy.foreach { subBlock => + subBlock.ports.toSeqMap.foreach { case (subBlockPortName, port) => + val disconnectedConnectOpt = port.is match { + case elem.PortLike.Is.Port(port) if !connectedsByPortRef.contains(Seq(subBlockName, subBlockPortName)) => + Some(PortConnectTyped(PortConnects.BlockPort(subBlockName, subBlockPortName), port.getSelfClass)) + case elem.PortLike.Is.Bundle(port) if !connectedsByPortRef.contains(Seq(subBlockName, subBlockPortName)) => + Some(PortConnectTyped(PortConnects.BlockPort(subBlockName, subBlockPortName), port.getSelfClass)) + case elem.PortLike.Is.Array(array) => // arrays can have multiple connections + val connectOpt = connectedsByPortRef.get(Seq(subBlockName, subBlockPortName)) match { + case None => // no prior connect, add default connect + Some(PortConnects.BlockVectorSlicePort(subBlockName, subBlockPortName, None)) + case Some(connecteds) => // prior connect, see if it's a slice and more can be appended + if (connecteds.forall(_.connect.isInstanceOf[PortConnects.BlockVectorSlicePort])) { + Some(PortConnects.BlockVectorSlicePort(subBlockName, subBlockPortName, None)) + } else if (connecteds.forall(_.connect.isInstanceOf[PortConnects.BlockVectorSliceVector])) { + Some(PortConnects.BlockVectorSliceVector(subBlockName, subBlockPortName, None)) + } else { + None + } + } + connectOpt.map(connect => PortConnectTyped(connect, array.getSelfClass)) + case _ => None + } + disconnectedConnectOpt.foreach { + disconnectedBlockPortConnections.append + } + } + } + } + + // returns all connections for this block, each connection being the link name (if part of a link), + // ports attached (as ConnectTypes.Base), and list of constraints + // disconnected ports returned as a single port of BlockPort (for block ports), BoundaryPort (for boundary ports), + // and TBD for vectors (currently includes a new slice, if in a slice connection or disconnected) + val connectedGroups: Seq[(Option[String], Seq[PortConnectTyped[PortConnects.ConstraintBase]], Seq[expr.ValueExpr])] = + connectionsBuilder.toSeq.map { case (linkNameOpt, connecteds, constrs) => + (linkNameOpt, connecteds.toSeq, constrs.toSeq) + } ++ (disconnectedBoundaryPortConnections ++ disconnectedBlockPortConnections).map { connected => + (None, Seq(connected), Seq()) + } +} diff --git a/src/main/scala/edg_ide/util/ConnectBuilder.scala b/src/main/scala/edg_ide/util/ConnectBuilder.scala new file mode 100644 index 00000000..888a2362 --- /dev/null +++ b/src/main/scala/edg_ide/util/ConnectBuilder.scala @@ -0,0 +1,292 @@ +package edg_ide.util + +import com.intellij.openapi.diagnostic.Logger +import edg.EdgirUtils.SimpleLibraryPath +import edg.ExprBuilder.ValueExpr +import edg.wir.ProtoUtil.{BlockProtoToSeqMap, PortProtoToSeqMap} +import edgir.elem.elem +import edgir.elem.elem.HierarchyBlock +import edgir.expr.expr +import edgir.ref.ref + +import scala.collection.mutable + +object PortConnects { // types of connections a port attached to a connection can be a part of + // TODO materialize into constraints? - how to add tack this on to an existing IR graph + sealed trait Base { + def getPortType(container: elem.HierarchyBlock): Option[ref.LibraryPath] // retrieves the type from the container + + def topPortRef: Seq[String] // returns the top level port path, as (block, port) or (port) for boundary ports + } + + // base trait for connect that can be generated from an IR constraint, as opposed to by GUI connects + sealed trait ConstraintBase extends Base + + sealed trait PortBase extends Base // base type for any port-valued connection + sealed trait VectorBase extends Base // base type for any vector-valued connection + sealed trait AmbiguousBase extends Base // base type for any connection which can be port- or vector-valued + + protected def typeOfSinglePort(portLike: elem.PortLike): Option[ref.LibraryPath] = portLike.is match { + case elem.PortLike.Is.Port(port) => port.selfClass + case elem.PortLike.Is.Bundle(port) => port.selfClass + case _ => None + } + + protected def typeOfArrayPort(portLike: elem.PortLike): Option[ref.LibraryPath] = portLike.is match { + case elem.PortLike.Is.Array(array) => array.selfClass + case _ => None + } + + // connects to bridges show up as this, though containing no information about the boundary port + case class BlockPort(blockName: String, portName: String) extends PortBase with ConstraintBase { + override def getPortType(container: HierarchyBlock): Option[ref.LibraryPath] = { + container.blocks.toSeqMap.get(blockName).flatMap(_.`type`.hierarchy) + .flatMap(_.ports.get(portName)) + .flatMap(typeOfSinglePort) + } + override def topPortRef: Seq[String] = Seq(blockName, portName) + } + + // single exported port only, getPortType ignores the inner names + case class BoundaryPort(portName: String, innerNames: Seq[String]) extends PortBase with ConstraintBase { + override def getPortType(container: HierarchyBlock): Option[ref.LibraryPath] = { + container.ports.get(portName) + .flatMap(typeOfSinglePort) + } + override def topPortRef: Seq[String] = Seq(portName) + } + + // port array, connected as a unit; port array cannot be part of any other connection + case class BlockVectorUnit(blockName: String, portName: String) extends VectorBase with ConstraintBase { + override def getPortType(container: HierarchyBlock): Option[ref.LibraryPath] = { + container.blocks.toSeqMap.get(blockName).flatMap(_.`type`.hierarchy) + .flatMap(_.ports.get(portName)) + .flatMap(typeOfArrayPort) + } + override def topPortRef: Seq[String] = Seq(blockName, portName) + } + + sealed trait BlockVectorSliceBase extends Base { + def blockName: String + def portName: String + override def getPortType(container: HierarchyBlock): Option[ref.LibraryPath] = { // same as BlockVectorUnit case + container.blocks.toSeqMap.get(blockName).flatMap(_.`type`.hierarchy) + .flatMap(_.ports.get(portName)) + .flatMap(typeOfArrayPort) + } + override def topPortRef: Seq[String] = Seq(blockName, portName) + } + + // port-typed slice of a port array + case class BlockVectorSlicePort(blockName: String, portName: String, suggestedIndex: Option[String]) + extends PortBase with BlockVectorSliceBase with ConstraintBase {} + + // vector-typed slice of a port array, connected using allocated / requested + case class BlockVectorSliceVector(blockName: String, portName: String, suggestedIndex: Option[String]) + extends VectorBase with BlockVectorSliceBase with ConstraintBase {} + + // ambiguous slice of a port array, with no corresponding IR construct but used intuitively for GUI connections + case class BlockVectorSlice(blockName: String, portName: String, suggestedIndex: Option[String]) + extends AmbiguousBase with BlockVectorSliceBase {} + + // port array, connected as a unit; port array cannot be part of any other connection + case class BoundaryPortVectorUnit(portName: String) extends VectorBase with ConstraintBase { + override def getPortType(container: HierarchyBlock): Option[ref.LibraryPath] = { + container.ports.get(portName) + .flatMap(typeOfArrayPort) + } + override def topPortRef: Seq[String] = Seq(portName) + } + + // turns an unlowered (but optionally expanded) connect expression into a structured connect type, if the form matches + // None means the expression failed to decode + def fromConnect(constr: expr.ValueExpr): Option[Seq[ConstraintBase]] = constr.expr match { + case expr.ValueExpr.Expr.Connected(connected) => + singleBlockPortFromRef(connected.getBlockPort).map(Seq(_)) + case expr.ValueExpr.Expr.Exported(exported) => + val exterior = exported.getExteriorPort match { + case ValueExpr.Ref(Seq(portName, innerNames @ _*)) => Some(BoundaryPort(portName, innerNames)) + case _ => None // invalid / unrecognized form + } + val interior = singleBlockPortFromRef(exported.getInternalBlockPort) + (exterior, interior) match { + case (Some(exterior), Some(interior)) => Some(Seq(exterior, interior)) + case _ => None // at least one failed to decode + } + case expr.ValueExpr.Expr.ConnectedArray(connectedArray) => + vectorBlockPortFromRef(connectedArray.getBlockPort).map(Seq(_)) + case expr.ValueExpr.Expr.ExportedArray(exportedArray) => + val exterior = exportedArray.getExteriorPort match { + // exported array only supported as a unit, the compiler cannot materialize subarray indices + case ValueExpr.Ref(Seq(portName)) => Some(BoundaryPortVectorUnit(portName)) + case _ => None // invalid / unrecognized form + } + val interior = vectorBlockPortFromRef(exportedArray.getInternalBlockPort) + (exterior, interior) match { + case (Some(exterior), Some(interior)) => Some(Seq(exterior, interior)) + case _ => None // at least one failed to decode + } + case _ => None + } + + protected def singleBlockPortFromRef(ref: expr.ValueExpr): Option[ConstraintBase] = ref match { + case ValueExpr.Ref(Seq(blockName, portName)) => Some(BlockPort(blockName, portName)) + case ValueExpr.RefAllocate(Seq(blockName, portName), suggestedName) => + Some(BlockVectorSlicePort(blockName, portName, suggestedName)) + case _ => None // invalid / unrecognized form + } + + protected def vectorBlockPortFromRef(ref: expr.ValueExpr): Option[ConstraintBase] = ref match { + case ValueExpr.Ref(Seq(blockName, portName)) => Some(BlockVectorUnit(blockName, portName)) + case ValueExpr.RefAllocate(Seq(blockName, portName), suggestedName) => + Some(BlockVectorSliceVector(blockName, portName, suggestedName)) + case _ => None // invalid / unrecognized form + } +} + +object PortConnectTyped { + def fromConnect[PortConnectType <: PortConnects.Base]( + connect: PortConnectType, + container: elem.HierarchyBlock + ): Option[PortConnectTyped[PortConnectType]] = + connect.getPortType(container).map(portType => PortConnectTyped(connect, portType)) + + // like fromConnect, but returns None if any connect could not have a type determined + def fromConnectsAll[PortConnectType <: PortConnects.Base]( + connects: Seq[PortConnectType], + container: elem.HierarchyBlock + ): Option[Seq[PortConnectTyped[PortConnectType]]] = + SeqUtils.getAllDefined(connects.map { PortConnectTyped.fromConnect(_, container) }) + + // returns whether the sequence of connects with port types is a direct export + def connectsIsExport(connects: Seq[PortConnectTyped[PortConnects.Base]]): Boolean = connects match { + case Seq( + PortConnectTyped(PortConnects.BoundaryPort(_, _), boundaryType), + PortConnectTyped(PortConnects.BlockPort(_, _), blockType) + ) if boundaryType == blockType => true + case Seq( + PortConnectTyped(PortConnects.BlockPort(_, _), blockType), + PortConnectTyped(PortConnects.BoundaryPort(_, _), boundaryType) + ) if boundaryType == blockType => true + case _ => false + } +} + +case class PortConnectTyped[+PortConnectType <: PortConnects.Base]( + connect: PortConnectType, + portType: ref.LibraryPath +) {} + +object ConnectMode { // state of a connect-in-progress + trait Base + case object Port extends Base // connection between single ports, generates into link + case object Vector extends Base // connection with at least one full vector, generates into link array + case object Ambiguous extends Base // connection which can be either - but is ambiguous and cannot be created +} + +object ConnectBuilder { + private val logger = Logger.getInstance(this.getClass) + + // creates a ConnectBuilder given all the connects to a link (found externally) and context data + def apply( + container: elem.HierarchyBlock, + linkLib: elem.Link, // link is needed to determine available connects + constrs: Seq[expr.ValueExpr] + ): Option[ConnectBuilder] = { + val availableOpt = SeqUtils.getAllDefined(linkLib.ports.toSeqMap.map { case (name, portLike) => + (name, portLike.is) + } + .map { // link libraries are pre-elaboration + case (name, elem.PortLike.Is.LibElem(port)) => Some((name, false, port)) + case (name, elem.PortLike.Is.Array(array)) => array.selfClass.map((name, true, _)) + case (name, port) => + logger.warn(s"unknown port type $name = ${port.getClass} in ${linkLib.getSelfClass.toSimpleString}") + None + }.toSeq) + + val constrConnectsOpt = SeqUtils.getAllDefined(constrs.map(PortConnects.fromConnect)).map(_.flatten) + val constrConnectTypedOpt = constrConnectsOpt.flatMap { constrConnects => + PortConnectTyped.fromConnectsAll(constrConnects, container) + } + + (availableOpt, constrConnectTypedOpt) match { + case (Some(available), Some(constrConnectTyped)) => + new ConnectBuilder(linkLib, container, available, Seq(), ConnectMode.Ambiguous).append(constrConnectTyped) + case _ => + if (availableOpt.isEmpty) { + logger.warn( + s"unable to compute available ports for ${linkLib.getSelfClass.toSimpleString} in ${container.getSelfClass.toSimpleString}" + ) + } + if (constrConnectTypedOpt.isEmpty) { + logger.warn( + s"unable to compute connected ports for ${linkLib.getSelfClass.toSimpleString} in ${container.getSelfClass.toSimpleString}" + ) + } + None + } + } +} + +/** Mildly analogous to the connect builder in the frontend HDL, this starts with a link, then ports can be added. + * Immutable, added ports return a new ConnectBuilder object (if the add was successful) or None (if the ports cannot + * be added). Accounts for bridging and vectors. Works at the link level of abstraction, no special support for + * bridges. + */ +class ConnectBuilder protected ( + val linkLib: elem.Link, // library + container: elem.HierarchyBlock, + protected val availablePorts: Seq[(String, Boolean, ref.LibraryPath)], // name, is array, port type + val connected: Seq[(PortConnectTyped[PortConnects.Base], String)], // connect type, used port type, port name + val connectMode: ConnectMode.Base +) { + // Attempts to append the connections (with attached port types) to the builder, returning a new builder + // (if successful) or None (if not a legal connect). + def append(newConnects: Seq[PortConnectTyped[PortConnects.Base]]): Option[ConnectBuilder] = { + val availablePortsBuilder = availablePorts.to(mutable.ArrayBuffer) + var connectModeBuilder = connectMode + var failedToAllocate: Boolean = false + + val newConnected = newConnects.map { connectTyped => + val portName = availablePortsBuilder.indexWhere(_._3 == connectTyped.portType) match { + case index if index >= 0 => + val (portName, isArray, portType) = availablePortsBuilder(index) + connectTyped.connect match { + case _: PortConnects.PortBase => + if (connectModeBuilder == ConnectMode.Vector) { + failedToAllocate = true + } else { + connectModeBuilder = ConnectMode.Port + } + case _: PortConnects.VectorBase => + if (connectModeBuilder == ConnectMode.Port) { + failedToAllocate = true + } else { + connectModeBuilder = ConnectMode.Vector + } + case _: PortConnects.AmbiguousBase => // fine in either single or vector case + } + if (!isArray) { + availablePortsBuilder.remove(index) + } + portName + case _ => + failedToAllocate = true + "" + } + (connectTyped, portName) + } + + if (failedToAllocate) { + None + } else { + Some(new ConnectBuilder( + linkLib, + container, + availablePortsBuilder.toSeq, + connected ++ newConnected, + connectModeBuilder + )) + } + } +} diff --git a/src/main/scala/edg_ide/util/DesignFindDisconnected.scala b/src/main/scala/edg_ide/util/DesignFindDisconnected.scala index a7975d6a..05ec43ea 100644 --- a/src/main/scala/edg_ide/util/DesignFindDisconnected.scala +++ b/src/main/scala/edg_ide/util/DesignFindDisconnected.scala @@ -38,9 +38,9 @@ object DesignFindDisconnected extends DesignBlockMap[(Seq[DesignPath], Seq[Strin val myConnectedPorts = myConstrExprs .collect { // extract block side expr - case expr.ValueExpr.Expr.Connected(expr.ConnectedExpr(Some(blockExpr), Some(linkExpr), _)) => + case expr.ValueExpr.Expr.Connected(expr.ConnectedExpr(Some(blockExpr), Some(linkExpr), _, _)) => blockExpr.expr - case expr.ValueExpr.Expr.Exported(expr.ExportedExpr(Some(exteriorExpr), Some(interiorExpr), _)) => + case expr.ValueExpr.Expr.Exported(expr.ExportedExpr(Some(exteriorExpr), Some(interiorExpr), _, _)) => interiorExpr.expr } .collect { // extract steps diff --git a/src/main/scala/edg_ide/util/EdgirConnectExecutor.scala b/src/main/scala/edg_ide/util/EdgirConnectExecutor.scala new file mode 100644 index 00000000..f1782d01 --- /dev/null +++ b/src/main/scala/edg_ide/util/EdgirConnectExecutor.scala @@ -0,0 +1,106 @@ +package edg_ide.util + +import com.intellij.openapi.diagnostic.Logger +import edg.ElemBuilder.Constraint +import edg.ExprBuilder.Ref +import edg.util.NameCreator +import edgir.elem.elem +import edgir.expr.expr + +object EdgirConnectExecutor { + private val logger = Logger.getInstance(this.getClass) + + // modifies the Block at the IR level to add new connections + // this is only for visualization purposes, does not need to handle constraint prop and whatnot + def apply( + container: elem.HierarchyBlock, + linkNameOpt: Option[String], + newConnected: ConnectBuilder, + startingPort: PortConnectTyped[PortConnects.Base], + newConnects: Seq[PortConnectTyped[PortConnects.Base]] + ): Option[elem.HierarchyBlock] = { + if (newConnects.isEmpty) { // nop + Some(container) + } else if (PortConnectTyped.connectsIsExport(newConnected.connected.map(_._1))) { // export + throw new IllegalArgumentException("TODO IMPLEMENT ME new direct export connect") + } else { // everything else is a link + applyLink(container, linkNameOpt, newConnected, startingPort, newConnects) + } + } + + protected def portConnectToConstraint( + connect: PortConnectTyped[PortConnects.Base], + connectBuilder: ConnectBuilder, + linkName: String + ): Option[expr.ValueExpr] = { + val linkPortName = connectBuilder.connected.find(_._1.connect == connect.connect).map(_._2).getOrElse { + logger.error(s"portConnectToConstraint: connect $connect not found in connected") + return None + } + val constr = connect.connect match { + // TODO: this produces a constraint that might not be valid (port arrays may not have the element, needs allocate), + // but is good enough for the visualizer + case PortConnects.BlockPort(blockName, portName) => + Constraint.Connected(Ref(blockName, portName), Ref(linkName, linkPortName)) + case PortConnects.BoundaryPort(portName, _) => + throw new IllegalArgumentException("TODO IMPLEMENT ME bridge connect") + case PortConnects.BlockVectorUnit(blockName, portName) => + throw new IllegalArgumentException("TODO IMPLEMENT ME link array connect") + case PortConnects.BlockVectorSlicePort(blockName, portName, _) => + Constraint.Connected(Ref(blockName, portName), Ref(linkName, linkPortName)) // TODO allocate on block side + case PortConnects.BlockVectorSliceVector(blockName, portName, _) => + throw new IllegalArgumentException("TODO IMPLEMENT ME link array connect") + case PortConnects.BlockVectorSlice(blockPort, portName, _) => + throw new IllegalArgumentException("TODO IMPLEMENT ME link array connect") + case PortConnects.BoundaryPortVectorUnit(portName) => + throw new IllegalArgumentException("TODO IMPLEMENT ME bridge connect") + } + Some(constr) + } + + // modifies the Block to add a link, or add connections to a link + protected def applyLink( + container: elem.HierarchyBlock, + linkNameOpt: Option[String], + newConnected: ConnectBuilder, + startingPort: PortConnectTyped[PortConnects.Base], + newConnects: Seq[PortConnectTyped[PortConnects.Base]] + ): Option[elem.HierarchyBlock] = { + var containerBuilder = container + val namer = NameCreator.fromBlock(container) + val linkName = linkNameOpt match { + case Some(linkName) => linkName // link already exists, add to it + case None => // no link exists, instantiate one + val linkNewName = namer.newName("_new") + val newConstrOpt = portConnectToConstraint(startingPort, newConnected, linkNewName) + val newConstrSeq = newConstrOpt.map(newConstr => + elem.NamedValueExpr( + name = namer.newName("_new"), + value = Some(newConstr) + ) + ).toSeq + + containerBuilder = containerBuilder.update( + _.links :+= elem.NamedLinkLike( + name = linkNewName, + value = Some(elem.LinkLike(elem.LinkLike.Type.Link(newConnected.linkLib))) + ), + _.constraints :++= newConstrSeq + ) + linkNewName + } + val newConstraints = newConnects.flatMap { newConnect => + val newConstrOpt = portConnectToConstraint(newConnect, newConnected, linkName) + newConstrOpt.map(newConstr => + elem.NamedValueExpr( + name = namer.newName("_new"), + value = Some(newConstr) + ) + ) + } + containerBuilder = containerBuilder.update( + _.constraints :++= newConstraints + ) + Some(containerBuilder) + } +} diff --git a/src/main/scala/edg_ide/util/SeqUtils.scala b/src/main/scala/edg_ide/util/SeqUtils.scala new file mode 100644 index 00000000..1493abc1 --- /dev/null +++ b/src/main/scala/edg_ide/util/SeqUtils.scala @@ -0,0 +1,16 @@ +package edg_ide.util + +object SeqUtils { + // if all elements defined, returns the seq of elements, else None + def getAllDefined[T](seq: Seq[Option[T]]): Option[Seq[T]] = { + val (some, none) = seq.partitionMap { + case Some(value) => Left(value) + case None => Right(None) + } + if (none.nonEmpty) { + None + } else { + Some(some) + } + } +} diff --git a/src/test/scala/edg_ide/util/tests/BlockConnectedAnalysisTest.scala b/src/test/scala/edg_ide/util/tests/BlockConnectedAnalysisTest.scala new file mode 100644 index 00000000..e22cfc21 --- /dev/null +++ b/src/test/scala/edg_ide/util/tests/BlockConnectedAnalysisTest.scala @@ -0,0 +1,76 @@ +package edg_ide.util.tests + +import edg.ElemBuilder._ +import edg_ide.util.{BlockConnectedAnalysis, PortConnectTyped, PortConnects} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers._ + +class BlockConnectedAnalysisTest extends AnyFlatSpec { + behavior.of("BlockConnectedAnalysis") + + private val connectBuilderTest = new ConnectBuilderTest() // for shared examples + + it should "decode connections for non-array example" in { + val analysis = new BlockConnectedAnalysis(connectBuilderTest.exampleBlock) + analysis.connectedGroups.map(_._1) should equal(Seq(Some("link"), None, None, None, None)) + analysis.connectedGroups(0)._1 should equal(Some("link")) + analysis.connectedGroups(0)._2 should equal(Seq( + PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")), + PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")), + PortConnectTyped(PortConnects.BlockPort("sink1", "port"), LibraryPath("sinkPort")), + PortConnectTyped(PortConnects.BlockVectorSlicePort("sinkArray", "port", None), LibraryPath("sinkPort")), + )) + analysis.connectedGroups(0)._3 should not be empty + + analysis.connectedGroups(1)._1 should equal(None) + analysis.connectedGroups(1)._2 should equal(Seq( // export + PortConnectTyped(PortConnects.BoundaryPort("port", Seq()), LibraryPath("sourcePort")), + PortConnectTyped(PortConnects.BlockPort("exportSource", "port"), LibraryPath("sourcePort")), + )) + analysis.connectedGroups(1)._3 should not be empty + + analysis.connectedGroups(2)._1 should equal(None) + analysis.connectedGroups(2)._2 should equal(Seq( // export into bundle elt + PortConnectTyped(PortConnects.BoundaryPort("bundle", Seq("port")), LibraryPath("sourceBundle")), + PortConnectTyped(PortConnects.BlockPort("exportBundleSource", "port"), LibraryPath("sourcePort")), + )) + analysis.connectedGroups(2)._3 should not be empty + + analysis.connectedGroups(3)._1 should equal(None) + analysis.connectedGroups(3)._2 should equal(Seq( // array can have additional allocates + PortConnectTyped(PortConnects.BlockVectorSlicePort("sinkArray", "port", None), LibraryPath("sinkPort")))) + analysis.connectedGroups(3)._3 shouldBe empty + + analysis.connectedGroups(4)._1 should equal(None) + analysis.connectedGroups(4)._2 should equal(Seq( + PortConnectTyped(PortConnects.BlockPort("unusedSink", "port"), LibraryPath("sinkPort")) + )) + analysis.connectedGroups(4)._3 shouldBe empty + } + + it should "decode connections for array example" in { + val analysis = new BlockConnectedAnalysis(connectBuilderTest.exampleArrayBlock) + analysis.connectedGroups.map(_._1) should equal(Seq(Some("link"), None, None)) + + analysis.connectedGroups(0)._1 should equal(Some("link")) + analysis.connectedGroups(0)._2 should equal(Seq( + PortConnectTyped(PortConnects.BlockVectorUnit("source", "port"), LibraryPath("sourcePort")), + PortConnectTyped(PortConnects.BlockVectorUnit("sink0", "port"), LibraryPath("sinkPort")), + PortConnectTyped(PortConnects.BlockVectorUnit("sink1", "port"), LibraryPath("sinkPort")), + PortConnectTyped(PortConnects.BlockVectorSliceVector("sinkArray", "port", None), LibraryPath("sinkPort")) + )) + analysis.connectedGroups(0)._3 should not be empty + + analysis.connectedGroups(1)._1 should equal(None) + analysis.connectedGroups(1)._2 should equal(Seq( // array can have additional allocates + PortConnectTyped(PortConnects.BlockVectorSliceVector("sinkArray", "port", None), LibraryPath("sinkPort")))) + analysis.connectedGroups(1)._3 shouldBe empty + + analysis.connectedGroups(2)._1 should equal(None) + // TODO this should be an ambiguous BlockVectorUnit, but currently defaults to a VectorSlicePort +// analysis.connectedGroups(2)._2 should equal(Seq( +// PortConnectTyped(PortConnects.BlockVectorUnit("unusedSinkArray", "port"), LibraryPath("sinkPort")) +// )) + analysis.connectedGroups(2)._3 shouldBe empty + } +} diff --git a/src/test/scala/edg_ide/util/tests/ConnectBuilderTest.scala b/src/test/scala/edg_ide/util/tests/ConnectBuilderTest.scala new file mode 100644 index 00000000..b2a13501 --- /dev/null +++ b/src/test/scala/edg_ide/util/tests/ConnectBuilderTest.scala @@ -0,0 +1,375 @@ +package edg_ide.util.tests + +import edg.ElemBuilder._ +import edg.ExprBuilder.{Ref, ValInit} +import edg.wir.ProtoUtil.ConstraintProtoToSeqMap +import edg_ide.util.{ConnectBuilder, ConnectMode, PortConnectTyped, PortConnects} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers._ + +import scala.collection.SeqMap + +class ConnectBuilderTest extends AnyFlatSpec { + behavior.of("ConnectBuilder") + + val exampleBlock = Block.Block( // basic skeletal and structural example of a block with various connects + "topDesign", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + "bundle" -> Port.Bundle( + "sourceBundle", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ) + ), + ), + blocks = SeqMap( + "source" -> Block.Block( + "sourceBlock", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sink0" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sink1" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sinkArray" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + "exportSource" -> Block.Block( + "sourceBlock", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "exportBundleSource" -> Block.Block( + "sourceBlock", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "unusedSink" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + ), + links = SeqMap( + "link" -> Link.Link("link"), + ), + constraints = SeqMap( + "sourceConnect" -> Constraint.Connected(Ref("source", "port"), Ref("link", "source")), + "sink0Connect" -> Constraint.Connected(Ref("sink0", "port"), Ref.Allocate(Ref("link", "sinks"))), + "sink1Connect" -> Constraint.Connected(Ref("sink1", "port"), Ref.Allocate(Ref("link", "sinks"))), + "sinkArrayConnect" -> Constraint.Connected( + Ref.Allocate(Ref("sinkArray", "port")), + Ref.Allocate(Ref("link", "sinks")) + ), + "sourceExport" -> Constraint.Exported(Ref("port"), Ref("exportSource", "port")), + "bundleSourceExport" -> Constraint.Exported(Ref("bundle", "port"), Ref("exportBundleSource", "port")), + ) + ).getHierarchy + + val exampleArrayBlock = Block.Block( // basic example using link arrays + "topDesign", + ports = SeqMap( + ), + blocks = SeqMap( + "source" -> Block.Block( + "sourceArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sourcePort", + Seq("0", "1"), + Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + "sink0" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0", "1"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + "sink1" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0", "1"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + "sinkArray" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0", "1"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + "unusedSinkArray" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0", "1"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + ), + links = SeqMap( + "link" -> Link.Array("link"), + ), + constraints = SeqMap( + "sourceConnect" -> Constraint.ConnectedArray(Ref("source", "port"), Ref("link", "source")), + "sink0Connect" -> Constraint.ConnectedArray(Ref("sink0", "port"), Ref.Allocate(Ref("link", "sinks"))), + "sink1Connect" -> Constraint.ConnectedArray(Ref("sink1", "port"), Ref.Allocate(Ref("link", "sinks"))), + "sinkArrayConnect" -> Constraint.ConnectedArray( + Ref.Allocate(Ref("sinkArray", "port")), + Ref.Allocate(Ref("link", "sinks")) + ), + ) + ).getHierarchy + + val exampleLink = Link.Link( + "link", + ports = SeqMap( + "source" -> Port.Library("sourcePort"), + "sinks" -> Port.Array("sinkPort"), + ), + params = SeqMap( + "param" -> ValInit.Integer + ), + ).getLink + + it should "decode connections and get types" in { + // basic connection forms + PortConnects.fromConnect(exampleBlock.constraints.toSeqMap("sourceConnect")) should equal(Some(Seq( + PortConnects.BlockPort("source", "port") + ))) + PortConnects.BlockPort("source", "port").getPortType(exampleBlock) should equal(Some(LibraryPath("sourcePort"))) + + PortConnects.fromConnect(exampleBlock.constraints.toSeqMap("sink0Connect")) should equal(Some(Seq( + PortConnects.BlockPort("sink0", "port") + ))) + PortConnects.BlockPort("sink0", "port").getPortType(exampleBlock) should equal(Some(LibraryPath("sinkPort"))) + + PortConnects.fromConnect(exampleBlock.constraints.toSeqMap("sourceExport")) should equal(Some(Seq( + PortConnects.BoundaryPort("port", Seq()), + PortConnects.BlockPort("exportSource", "port") + ))) + PortConnects.BoundaryPort("port", Seq()).getPortType(exampleBlock) should equal(Some(LibraryPath("sourcePort"))) + PortConnects.BlockPort("exportSource", "port").getPortType(exampleBlock) should equal( + Some(LibraryPath("sourcePort")) + ) + + // export into bundle component / vector element + PortConnects.fromConnect(exampleBlock.constraints.toSeqMap("bundleSourceExport")) should equal(Some(Seq( + PortConnects.BoundaryPort("bundle", Seq("port")), + PortConnects.BlockPort("exportBundleSource", "port") + ))) + PortConnects.BoundaryPort("bundle", Seq("port")).getPortType(exampleBlock) should equal( + Some(LibraryPath("sourceBundle")) + ) + PortConnects.BlockPort("exportBundleSource", "port").getPortType(exampleBlock) should equal( + Some(LibraryPath("sourcePort")) + ) + } + + it should "build valid connects from empty, starting with a port" in { + val emptyConnect = ConnectBuilder(exampleBlock, exampleLink, Seq()) + emptyConnect should not be empty + emptyConnect.get.connectMode should equal(ConnectMode.Ambiguous) + + val sourceConnect = emptyConnect.get.append(Seq( + PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")) + )) + sourceConnect should not be empty + sourceConnect.get.connectMode should equal(ConnectMode.Port) + + val sink0Connect = sourceConnect.get.append(Seq( + PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")) + )) + sink0Connect should not be empty + sink0Connect.get.connectMode should equal(ConnectMode.Port) + sink0Connect.get.connected should equal(Seq( + (PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")), "source"), + (PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")), "sinks") + )) + } + + it should "build valid port connects from empty, starting with an ambiguous vector" in { + val emptyConnect = ConnectBuilder(exampleBlock, exampleLink, Seq()) + val sliceConnect = emptyConnect.get.append(Seq( + PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")) + )) + sliceConnect should not be empty + sliceConnect.get.connectMode should equal(ConnectMode.Ambiguous) // stays ambiguous + + val sink0Connect = sliceConnect.get.append(Seq( + PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")) + )) + sink0Connect should not be empty + sink0Connect.get.connectMode should equal(ConnectMode.Port) // connection type resolved here + sink0Connect.get.connected should equal(Seq( + (PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")), "sinks"), + (PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")), "sinks") + )) + } + + it should "build valid array connects from empty, starting with an ambiguous vector" in { + val emptyConnect = ConnectBuilder(exampleArrayBlock, exampleLink, Seq()) + val sliceConnect = emptyConnect.get.append(Seq( + PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")) + )) + + val sinkArrayConnect = sliceConnect.get.append(Seq( + PortConnectTyped(PortConnects.BlockVectorUnit("sink0", "port"), LibraryPath("sinkPort")) + )) + sinkArrayConnect should not be empty + sinkArrayConnect.get.connectMode should equal(ConnectMode.Vector) // connection type resolved here + sinkArrayConnect.get.connected should equal(Seq( + (PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")), "sinks"), + (PortConnectTyped(PortConnects.BlockVectorUnit("sink0", "port"), LibraryPath("sinkPort")), "sinks") + )) + } + + it should "build valid connects with multiple allocations to an array" in { + val connect = ConnectBuilder( + exampleBlock, + exampleLink, + Seq( + exampleBlock.constraints.toSeqMap("sink0Connect"), + exampleBlock.constraints.toSeqMap("sink1Connect"), + exampleBlock.constraints.toSeqMap("sinkArrayConnect") + ) + ) + connect should not be empty + connect.get.connectMode should equal(ConnectMode.Port) + } + + it should "build valid array connects with multiple allocations to an array" in { + val connect = ConnectBuilder( + exampleArrayBlock, + exampleLink, + Seq( + exampleArrayBlock.constraints.toSeqMap("sink0Connect"), + exampleArrayBlock.constraints.toSeqMap("sink1Connect"), + exampleArrayBlock.constraints.toSeqMap("sinkArrayConnect") + ) + ) + connect should not be empty + connect.get.connectMode should equal(ConnectMode.Vector) + } + + it should "not overallocate single connects" in { + val sourceConnect = ConnectBuilder( + exampleBlock, + exampleLink, + Seq(exampleBlock.constraints.toSeqMap("sourceConnect")) + ).get + sourceConnect.append(Seq( + PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")) + )) shouldBe empty + } + + it should "not overallocate vector connects" in { + val sourceConnect = ConnectBuilder( + exampleArrayBlock, + exampleLink, + Seq(exampleArrayBlock.constraints.toSeqMap("sourceConnect")) + ).get + sourceConnect.append(Seq( + PortConnectTyped(PortConnects.BlockVectorSlice("source", "port", None), LibraryPath("sourcePort")) + )) shouldBe empty + } + + it should "not mix port and vector connects" in { + val sourceConnect = ConnectBuilder( + exampleArrayBlock, + exampleLink, + Seq(exampleArrayBlock.constraints.toSeqMap("sourceConnect")) + ).get + sourceConnect.append(Seq( + PortConnectTyped(PortConnects.BlockPort("sink", "port"), LibraryPath("sinkPort")) + )) shouldBe empty + } + + it should "allow appending vector slices in port mode" in { + val sourceConnect = ConnectBuilder( + exampleBlock, + exampleLink, + Seq( + exampleBlock.constraints.toSeqMap("sourceConnect"), + exampleBlock.constraints.toSeqMap("sink0Connect") + ) + ).get + val sinkConnected = sourceConnect.append(Seq( + PortConnectTyped(PortConnects.BlockPort("sink1", "port"), LibraryPath("sinkPort")) + )) + sinkConnected should not be empty + val sliceConnected = sinkConnected.get.append(Seq( + PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")) + )) + sliceConnected should not be empty + sliceConnected.get.connectMode should equal(ConnectMode.Port) + sliceConnected.get.connected should equal(Seq( + (PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")), "source"), + (PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")), "sinks"), + (PortConnectTyped(PortConnects.BlockPort("sink1", "port"), LibraryPath("sinkPort")), "sinks"), + (PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")), "sinks"), + )) + } + + it should "allow appending vector slices in vector mode" in { + val sourceConnect = ConnectBuilder( + exampleArrayBlock, + exampleLink, + Seq( + exampleArrayBlock.constraints.toSeqMap("sourceConnect"), + exampleArrayBlock.constraints.toSeqMap("sink0Connect") + ) + ).get + val sinkConnected = sourceConnect.append(Seq( + PortConnectTyped(PortConnects.BlockVectorUnit("sink1", "port"), LibraryPath("sinkPort")) + )) + sinkConnected should not be empty + val sliceConnected = sinkConnected.get.append(Seq( + PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")) + )) + sliceConnected should not be empty + sliceConnected.get.connectMode should equal(ConnectMode.Vector) + sliceConnected.get.connected should equal(Seq( + (PortConnectTyped(PortConnects.BlockVectorUnit("source", "port"), LibraryPath("sourcePort")), "source"), + (PortConnectTyped(PortConnects.BlockVectorUnit("sink0", "port"), LibraryPath("sinkPort")), "sinks"), + (PortConnectTyped(PortConnects.BlockVectorUnit("sink1", "port"), LibraryPath("sinkPort")), "sinks"), + (PortConnectTyped(PortConnects.BlockVectorSlice("sinkArray", "port", None), LibraryPath("sinkPort")), "sinks"), + )) + } +} diff --git a/src/test/scala/edg_ide/util/tests/EdgirConnectExecutorTest.scala b/src/test/scala/edg_ide/util/tests/EdgirConnectExecutorTest.scala new file mode 100644 index 00000000..c1467cf5 --- /dev/null +++ b/src/test/scala/edg_ide/util/tests/EdgirConnectExecutorTest.scala @@ -0,0 +1,173 @@ +package edg_ide.util.tests + +import edgir.elem.elem +import edg.ElemBuilder._ +import edg.ExprBuilder.{Ref, ValInit} +import edg.wir.ProtoUtil.{ConstraintProtoToSeqMap, LinkProtoToSeqMap} +import edg_ide.util.{BlockConnectedAnalysis, ConnectBuilder, EdgirConnectExecutor, PortConnectTyped, PortConnects} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers._ + +import scala.collection.SeqMap + +class EdgirConnectExecutorTest extends AnyFlatSpec { + behavior.of("EdgirConnectExecutor") + + private val connectBuilderTest = new ConnectBuilderTest() // for shared examples + + val emptyBlock = Block.Block( + "topDesign", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + blocks = SeqMap( + "source" -> Block.Block( + "sourceBlock", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sink0" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sink1" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sinkArray" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + ), + ).getHierarchy + + it should "create a new link, source first" in { + val startingPort = PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")) + val newConnects = Seq( + PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")), + PortConnectTyped(PortConnects.BlockPort("sink1", "port"), LibraryPath("sinkPort")), + ) + val newConnected = ConnectBuilder(emptyBlock, connectBuilderTest.exampleLink, Seq()).get + .append(startingPort +: newConnects).get + val connected = EdgirConnectExecutor(emptyBlock, None, newConnected, startingPort, newConnects).get + connected.links.toSeqMap should equal(SeqMap( + "_new" -> elem.LinkLike(elem.LinkLike.Type.Link(connectBuilderTest.exampleLink)) + )) + connected.constraints.toSeqMap should equal(SeqMap( + "_new2" -> Constraint.Connected(Ref("source", "port"), Ref("_new", "source")), + "_new3" -> Constraint.Connected(Ref("sink0", "port"), Ref("_new", "sinks")), // TODO link allocate + "_new4" -> Constraint.Connected(Ref("sink1", "port"), Ref("_new", "sinks")), // TODO link allocate + )) + } + + it should "create a new link, sink first" in { + val startingPort = PortConnectTyped(PortConnects.BlockPort("sink0", "port"), LibraryPath("sinkPort")) + val newConnects = Seq( + PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")), + PortConnectTyped(PortConnects.BlockPort("sink1", "port"), LibraryPath("sinkPort")), + ) + val newConnected = ConnectBuilder(emptyBlock, connectBuilderTest.exampleLink, Seq()).get + .append(startingPort +: newConnects).get + val connected = EdgirConnectExecutor(emptyBlock, None, newConnected, startingPort, newConnects).get + connected.links.toSeqMap should equal(SeqMap( + "_new" -> elem.LinkLike(elem.LinkLike.Type.Link(connectBuilderTest.exampleLink)) + )) + connected.constraints.toSeqMap should equal(SeqMap( + "_new2" -> Constraint.Connected(Ref("sink0", "port"), Ref("_new", "sinks")), // TODO link allocate + "_new3" -> Constraint.Connected(Ref("source", "port"), Ref("_new", "source")), + "_new4" -> Constraint.Connected(Ref("sink1", "port"), Ref("_new", "sinks")), // TODO link allocate + )) + } + + it should "create a new link, with vector slice port" in { + val startingPort = PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")) + val newConnects = Seq( + PortConnectTyped(PortConnects.BlockVectorSlicePort("sinkArray", "port", None), LibraryPath("sinkPort")), + ) + val newConnected = ConnectBuilder(emptyBlock, connectBuilderTest.exampleLink, Seq()).get + .append(startingPort +: newConnects).get + val connected = EdgirConnectExecutor(emptyBlock, None, newConnected, startingPort, newConnects).get + connected.links.toSeqMap should equal(SeqMap( + "_new" -> elem.LinkLike(elem.LinkLike.Type.Link(connectBuilderTest.exampleLink)) + )) + connected.constraints.toSeqMap should equal(SeqMap( + "_new2" -> Constraint.Connected(Ref("source", "port"), Ref("_new", "source")), + "_new3" -> Constraint.Connected(Ref("sinkArray", "port"), Ref("_new", "sinks")), // TODO link and block allocate + )) + } + + val connectedBlock = Block.Block( + "topDesign", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + blocks = SeqMap( + "source" -> Block.Block( + "sourceBlock", + ports = SeqMap( + "port" -> Port.Port("sourcePort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sink0" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sink1" -> Block.Block( + "sinkBlock", + ports = SeqMap( + "port" -> Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)), + ), + ), + "sinkArray" -> Block.Block( + "sinkArrayBlock", + ports = SeqMap( + "port" -> Port.Array( + "sinkPort", + Seq("0"), + Port.Port("sinkPort", params = SeqMap("param" -> ValInit.Integer)) + ), + ), + ), + ), + links = SeqMap( + "link" -> Link.Link("link"), + ), + constraints = SeqMap( + "sourceConnect" -> Constraint.Connected(Ref("source", "port"), Ref("link", "source")), + "sink0Connect" -> Constraint.Connected(Ref("sink0", "port"), Ref.Allocate(Ref("link", "sinks"))), + ), + ).getHierarchy + + it should "append to existing link" in { + val startingPort = PortConnectTyped(PortConnects.BlockPort("source", "port"), LibraryPath("sourcePort")) + val newConnects = Seq( + PortConnectTyped(PortConnects.BlockPort("sink1", "port"), LibraryPath("sinkPort")), + ) + val newConnected = ConnectBuilder( + connectedBlock, + connectBuilderTest.exampleLink, + connectedBlock.constraints.toSeqMap.values.toSeq + ).get + .append(newConnects).get + val connected = EdgirConnectExecutor(connectedBlock, Some("link"), newConnected, startingPort, newConnects).get + connected.links.toSeqMap.keys should equal(Set("link")) // no new port should be added + connected.constraints.toSeqMap should equal(SeqMap( + "sourceConnect" -> Constraint.Connected(Ref("source", "port"), Ref("link", "source")), + "sink0Connect" -> Constraint.Connected(Ref("sink0", "port"), Ref.Allocate(Ref("link", "sinks"))), + "_new" -> Constraint.Connected(Ref("sink1", "port"), Ref("link", "sinks")), // TODO should be allocate + )) + } +} diff --git a/src/test/scala/edg_ide/util/tests/SeqUtilsTest.scala b/src/test/scala/edg_ide/util/tests/SeqUtilsTest.scala new file mode 100644 index 00000000..150f40d1 --- /dev/null +++ b/src/test/scala/edg_ide/util/tests/SeqUtilsTest.scala @@ -0,0 +1,24 @@ +package edg_ide.util.tests + +import edg_ide.util.SeqUtils +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers._ + +class SeqUtilsTest extends AnyFlatSpec { + behavior.of("IterableUtils") + + it should "work for Some" in { + SeqUtils.getAllDefined(Seq(Some(1))) should equal(Some(Seq(1))) + SeqUtils.getAllDefined(Seq(Some(1), Some(2))) should equal(Some(Seq(1, 2))) + } + + it should "work for None" in { + SeqUtils.getAllDefined(Seq(None)) should equal(None) + SeqUtils.getAllDefined(Seq(None, None)) should equal(None) + } + + it should "work for mixed" in { + SeqUtils.getAllDefined(Seq(Some(1), None)) should equal(None) + SeqUtils.getAllDefined(Seq(Some(1), None, Some(3))) should equal(None) + } +}