Skip to content

Commit

Permalink
Refactor websocket (#4879)
Browse files Browse the repository at this point in the history
Co-authored-by: sp.wack <[email protected]>
  • Loading branch information
tofarr and amanape authored Nov 11, 2024
1 parent 79492b6 commit a1a9d2f
Show file tree
Hide file tree
Showing 19 changed files with 487 additions and 465 deletions.
14 changes: 7 additions & 7 deletions frontend/__tests__/components/chat/chat-interface.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ describe("Empty state", () => {
send: vi.fn(),
}));

const { useSocket: useSocketMock } = vi.hoisted(() => ({
useSocket: vi.fn(() => ({ send: sendMock, runtimeActive: true })),
const { useWsClient: useWsClientMock } = vi.hoisted(() => ({
useWsClient: vi.fn(() => ({ send: sendMock, runtimeActive: true })),
}));

beforeAll(() => {
vi.mock("#/context/socket", async (importActual) => ({
...(await importActual<typeof import("#/context/socket")>()),
useSocket: useSocketMock,
...(await importActual<typeof import("#/context/ws-client-provider")>()),
useWsClient: useWsClientMock,
}));
});

Expand Down Expand Up @@ -77,7 +77,7 @@ describe("Empty state", () => {
"should load the a user message to the input when selecting",
async () => {
// this is to test that the message is in the UI before the socket is called
useSocketMock.mockImplementation(() => ({
useWsClientMock.mockImplementation(() => ({
send: sendMock,
runtimeActive: false, // mock an inactive runtime setup
}));
Expand Down Expand Up @@ -106,7 +106,7 @@ describe("Empty state", () => {
it.fails(
"should send the message to the socket only if the runtime is active",
async () => {
useSocketMock.mockImplementation(() => ({
useWsClientMock.mockImplementation(() => ({
send: sendMock,
runtimeActive: false, // mock an inactive runtime setup
}));
Expand All @@ -123,7 +123,7 @@ describe("Empty state", () => {
await user.click(displayedSuggestions[0]);
expect(sendMock).not.toHaveBeenCalled();

useSocketMock.mockImplementation(() => ({
useWsClientMock.mockImplementation(() => ({
send: sendMock,
runtimeActive: true, // mock an active runtime setup
}));
Expand Down
20 changes: 16 additions & 4 deletions frontend/__tests__/hooks/use-terminal.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ import { beforeAll, describe, expect, it, vi } from "vitest";
import { render } from "@testing-library/react";
import { afterEach } from "node:test";
import { useTerminal } from "#/hooks/useTerminal";
import { SocketProvider } from "#/context/socket";
import { Command } from "#/state/commandSlice";
import { WsClientProvider } from "#/context/ws-client-provider";
import { ReactNode } from "react";

interface TestTerminalComponentProps {
commands: Command[];
Expand All @@ -18,6 +19,17 @@ function TestTerminalComponent({
return <div ref={ref} />;
}

interface WrapperProps {
children: ReactNode;
}


function Wrapper({children}: WrapperProps) {
return (
<WsClientProvider enabled={true} token="NO_JWT" ghToken="NO_GITHUB" settings={null}>{children}</WsClientProvider>
)
}

describe("useTerminal", () => {
const mockTerminal = vi.hoisted(() => ({
loadAddon: vi.fn(),
Expand Down Expand Up @@ -50,7 +62,7 @@ describe("useTerminal", () => {

it("should render", () => {
render(<TestTerminalComponent commands={[]} secrets={[]} />, {
wrapper: SocketProvider,
wrapper: Wrapper,
});
});

Expand All @@ -61,7 +73,7 @@ describe("useTerminal", () => {
];

render(<TestTerminalComponent commands={commands} secrets={[]} />, {
wrapper: SocketProvider,
wrapper: Wrapper,
});

expect(mockTerminal.writeln).toHaveBeenNthCalledWith(1, "echo hello");
Expand All @@ -85,7 +97,7 @@ describe("useTerminal", () => {
secrets={[secret, anotherSecret]}
/>,
{
wrapper: SocketProvider,
wrapper: Wrapper,
},
);

Expand Down
4 changes: 2 additions & 2 deletions frontend/src/components/AgentControlBar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import PlayIcon from "#/assets/play";
import { generateAgentStateChangeEvent } from "#/services/agentStateService";
import { RootState } from "#/store";
import AgentState from "#/types/AgentState";
import { useSocket } from "#/context/socket";
import { useWsClient } from "#/context/ws-client-provider";

const IgnoreTaskStateMap: Record<string, AgentState[]> = {
[AgentState.PAUSED]: [
Expand Down Expand Up @@ -72,7 +72,7 @@ function ActionButton({
}

function AgentControlBar() {
const { send } = useSocket();
const { send } = useWsClient();
const { curAgentState } = useSelector((state: RootState) => state.agent);

const handleAction = (action: AgentState) => {
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/components/chat-interface.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { useDispatch, useSelector } from "react-redux";
import React from "react";
import posthog from "posthog-js";
import { useSocket } from "#/context/socket";
import { convertImageToBase64 } from "#/utils/convert-image-to-base-64";
import { ChatMessage } from "./chat-message";
import { FeedbackActions } from "./feedback-actions";
Expand All @@ -22,13 +21,14 @@ import { ScrollToBottomButton } from "./scroll-to-bottom-button";
import { Suggestions } from "./suggestions";
import { SUGGESTIONS } from "#/utils/suggestions";
import BuildIt from "#/icons/build-it.svg?react";
import { useWsClient } from "#/context/ws-client-provider";

const isErrorMessage = (
message: Message | ErrorMessage,
): message is ErrorMessage => "error" in message;

export function ChatInterface() {
const { send } = useSocket();
const { send } = useWsClient();
const dispatch = useDispatch();
const scrollRef = React.useRef<HTMLDivElement>(null);
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
Expand Down
4 changes: 2 additions & 2 deletions frontend/src/components/chat/ConfirmationButtons.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import RejectIcon from "#/assets/reject";
import { I18nKey } from "#/i18n/declaration";
import AgentState from "#/types/AgentState";
import { generateAgentStateChangeEvent } from "#/services/agentStateService";
import { useSocket } from "#/context/socket";
import { useWsClient } from "#/context/ws-client-provider";

interface ActionTooltipProps {
type: "confirm" | "reject";
Expand Down Expand Up @@ -37,7 +37,7 @@ function ActionTooltip({ type, onClick }: ActionTooltipProps) {

function ConfirmationButtons() {
const { t } = useTranslation();
const { send } = useSocket();
const { send } = useWsClient();

const handleStateChange = (state: AgentState) => {
const event = generateAgentStateChangeEvent(state);
Expand Down
188 changes: 188 additions & 0 deletions frontend/src/components/event-handler.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import React from "react";
import {
useFetcher,
useLoaderData,
useRouteLoaderData,
} from "@remix-run/react";
import { useDispatch, useSelector } from "react-redux";
import toast from "react-hot-toast";

import posthog from "posthog-js";
import {
useWsClient,
WsClientProviderStatus,
} from "#/context/ws-client-provider";
import { ErrorObservation } from "#/types/core/observations";
import { addErrorMessage, addUserMessage } from "#/state/chatSlice";
import { handleAssistantMessage } from "#/services/actions";
import {
getCloneRepoCommand,
getGitHubTokenCommand,
} from "#/services/terminalService";
import {
clearFiles,
clearSelectedRepository,
setImportedProjectZip,
} from "#/state/initial-query-slice";
import { clientLoader as appClientLoader } from "#/routes/_oh.app";
import store, { RootState } from "#/store";
import { createChatMessage } from "#/services/chatService";
import { clientLoader as rootClientLoader } from "#/routes/_oh";
import { isGitHubErrorReponse } from "#/api/github";
import OpenHands from "#/api/open-hands";
import { base64ToBlob } from "#/utils/base64-to-blob";
import { setCurrentAgentState } from "#/state/agentSlice";
import AgentState from "#/types/AgentState";
import { getSettings } from "#/services/settings";

interface ServerError {
error: boolean | string;
message: string;
[key: string]: unknown;
}

const isServerError = (data: object): data is ServerError => "error" in data;

const isErrorObservation = (data: object): data is ErrorObservation =>
"observation" in data && data.observation === "error";

export function EventHandler({ children }: React.PropsWithChildren) {
const { events, status, send } = useWsClient();
const statusRef = React.useRef<WsClientProviderStatus | null>(null);
const runtimeActive = status === WsClientProviderStatus.ACTIVE;
const fetcher = useFetcher();
const dispatch = useDispatch();
const { files, importedProjectZip } = useSelector(
(state: RootState) => state.initalQuery,
);
const { ghToken, repo } = useLoaderData<typeof appClientLoader>();
const initialQueryRef = React.useRef<string | null>(
store.getState().initalQuery.initialQuery,
);

const sendInitialQuery = (query: string, base64Files: string[]) => {
const timestamp = new Date().toISOString();
send(createChatMessage(query, base64Files, timestamp));
};
const data = useRouteLoaderData<typeof rootClientLoader>("routes/_oh");
const userId = React.useMemo(() => {
if (data?.user && !isGitHubErrorReponse(data.user)) return data.user.id;
return null;
}, [data?.user]);
const userSettings = getSettings();

React.useEffect(() => {
if (!events.length) {
return;
}
const event = events[events.length - 1];
if (event.token) {
fetcher.submit({ token: event.token as string }, { method: "post" });
return;
}

if (isServerError(event)) {
if (event.error_code === 401) {
toast.error("Session expired.");
fetcher.submit({}, { method: "POST", action: "/end-session" });
return;
}

if (typeof event.error === "string") {
toast.error(event.error);
} else {
toast.error(event.message);
}
return;
}

if (isErrorObservation(event)) {
dispatch(
addErrorMessage({
id: event.extras?.error_id,
message: event.message,
}),
);
return;
}
handleAssistantMessage(event);
}, [events.length]);

React.useEffect(() => {
if (statusRef.current === status) {
return; // This is a check because of strict mode - if the status did not change, don't do anything
}
statusRef.current = status;
const initialQuery = initialQueryRef.current;

if (status === WsClientProviderStatus.ACTIVE) {
let additionalInfo = "";
if (ghToken && repo) {
send(getCloneRepoCommand(ghToken, repo));
additionalInfo = `Repository ${repo} has been cloned to /workspace. Please check the /workspace for files.`;
dispatch(clearSelectedRepository()); // reset selected repository; maybe better to move this to '/'?
}
// if there's an uploaded project zip, add it to the chat
else if (importedProjectZip) {
additionalInfo = `Files have been uploaded. Please check the /workspace for files.`;
}

if (initialQuery) {
if (additionalInfo) {
sendInitialQuery(`${initialQuery}\n\n[${additionalInfo}]`, files);
} else {
sendInitialQuery(initialQuery, files);
}
dispatch(clearFiles()); // reset selected files
initialQueryRef.current = null;
}
}

if (status === WsClientProviderStatus.OPENING && initialQuery) {
dispatch(
addUserMessage({
content: initialQuery,
imageUrls: files,
timestamp: new Date().toISOString(),
}),
);
}

if (status === WsClientProviderStatus.STOPPED) {
store.dispatch(setCurrentAgentState(AgentState.STOPPED));
}
}, [status]);

React.useEffect(() => {
if (runtimeActive && userId && ghToken) {
// Export if the user valid, this could happen mid-session so it is handled here
send(getGitHubTokenCommand(ghToken));
}
}, [userId, ghToken, runtimeActive]);

React.useEffect(() => {
(async () => {
if (runtimeActive && importedProjectZip) {
// upload files action
try {
const blob = base64ToBlob(importedProjectZip);
const file = new File([blob], "imported-project.zip", {
type: blob.type,
});
await OpenHands.uploadFiles([file]);
dispatch(setImportedProjectZip(null));
} catch (error) {
toast.error("Failed to upload project files.");
}
}
})();
}, [runtimeActive, importedProjectZip]);

React.useEffect(() => {
if (userSettings.LLM_API_KEY) {
posthog.capture("user_activated");
}
}, [userSettings.LLM_API_KEY]);

return children;
}
4 changes: 2 additions & 2 deletions frontend/src/components/project-menu/ProjectMenuCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import EllipsisH from "#/icons/ellipsis-h.svg?react";
import { ModalBackdrop } from "../modals/modal-backdrop";
import { ConnectToGitHubModal } from "../modals/connect-to-github-modal";
import { addUserMessage } from "#/state/chatSlice";
import { useSocket } from "#/context/socket";
import { createChatMessage } from "#/services/chatService";
import { ProjectMenuCardContextMenu } from "./project.menu-card-context-menu";
import { ProjectMenuDetailsPlaceholder } from "./project-menu-details-placeholder";
import { ProjectMenuDetails } from "./project-menu-details";
import { downloadWorkspace } from "#/utils/download-workspace";
import { LoadingSpinner } from "../modals/LoadingProject";
import { useWsClient } from "#/context/ws-client-provider";

interface ProjectMenuCardProps {
isConnectedToGitHub: boolean;
Expand All @@ -27,7 +27,7 @@ export function ProjectMenuCard({
isConnectedToGitHub,
githubData,
}: ProjectMenuCardProps) {
const { send } = useSocket();
const { send } = useWsClient();
const dispatch = useDispatch();

const [contextMenuIsOpen, setContextMenuIsOpen] = React.useState(false);
Expand Down
Loading

0 comments on commit a1a9d2f

Please sign in to comment.