diff --git a/app/packages/core/src/components/Modal/ImaVidLooker.tsx b/app/packages/core/src/components/Modal/ImaVidLooker.tsx index 6707202a8a..952d6a1179 100644 --- a/app/packages/core/src/components/Modal/ImaVidLooker.tsx +++ b/app/packages/core/src/components/Modal/ImaVidLooker.tsx @@ -17,13 +17,14 @@ import React, { import { useErrorHandler } from "react-error-boundary"; import { useRecoilValue, useSetRecoilState } from "recoil"; import { v4 as uuid } from "uuid"; -import { useInitializeImaVidSubscriptions, useModalContext } from "./hooks"; +import { useClearSelectedLabels, useShowOverlays } from "./ModalLooker"; import { - shortcutToHelpItems, - useClearSelectedLabels, + useInitializeImaVidSubscriptions, useLookerOptionsUpdate, - useShowOverlays, -} from "./ModalLooker"; + useModalContext, +} from "./hooks"; +import useKeyEvents from "./use-key-events"; +import { shortcutToHelpItems } from "./utils"; interface ImaVidLookerReactProps { sample: fos.ModalSample; @@ -132,19 +133,7 @@ export const ImaVidLookerReact = React.memo( useEventHandler(looker, "clear", useClearSelectedLabels()); - const hoveredSample = useRecoilValue(fos.hoveredSample); - - useEffect(() => { - const hoveredSampleId = hoveredSample?._id; - looker.updater((state) => ({ - ...state, - // todo: always setting it to true might not be wise - shouldHandleKeyEvents: true, - options: { - ...state.options, - }, - })); - }, [hoveredSample, sample, looker]); + useKeyEvents(initialRef, sample._id, looker); const ref = useRef(null); useEffect(() => { diff --git a/app/packages/core/src/components/Modal/ModalLooker.tsx b/app/packages/core/src/components/Modal/ModalLooker.tsx index c18eb5e048..08b52c7ca0 100644 --- a/app/packages/core/src/components/Modal/ModalLooker.tsx +++ b/app/packages/core/src/components/Modal/ModalLooker.tsx @@ -1,35 +1,12 @@ import { useTheme } from "@fiftyone/components"; -import { AbstractLooker } from "@fiftyone/looker"; -import { BaseState } from "@fiftyone/looker/src/state"; +import type { ImageLooker } from "@fiftyone/looker"; import * as fos from "@fiftyone/state"; -import { useEventHandler, useOnSelectLabel } from "@fiftyone/state"; -import React, { useEffect, useMemo, useRef, useState } from "react"; -import { useErrorHandler } from "react-error-boundary"; +import React, { useEffect, useMemo } from "react"; import { useRecoilCallback, useRecoilValue, useSetRecoilState } from "recoil"; -import { v4 as uuid } from "uuid"; -import { useModalContext } from "./hooks"; import { ImaVidLookerReact } from "./ImaVidLooker"; - -export const useLookerOptionsUpdate = () => { - return useRecoilCallback( - ({ snapshot, set }) => - async (update: object, updater?: (updated: {}) => void) => { - const currentOptions = await snapshot.getPromise( - fos.savedLookerOptions - ); - - const panels = await snapshot.getPromise(fos.lookerPanels); - const updated = { - ...currentOptions, - ...update, - showJSON: panels.json.isOpen, - showHelp: panels.help.isOpen, - }; - set(fos.savedLookerOptions, updated); - if (updater) updater(updated); - } - ); -}; +import { VideoLookerReact } from "./VideoLooker"; +import { useModalContext } from "./hooks"; +import useLooker from "./use-looker"; export const useShowOverlays = () => { return useRecoilCallback(({ set }) => async (event: CustomEvent) => { @@ -47,137 +24,40 @@ export const useClearSelectedLabels = () => { }; interface LookerProps { - sample?: fos.ModalSample; - onClick?: React.MouseEventHandler; + sample: fos.ModalSample; } -const ModalLookerNoTimeline = React.memo( - ({ sample: sampleDataWithExtraParams }: LookerProps) => { - const [id] = useState(() => uuid()); - const colorScheme = useRecoilValue(fos.colorScheme); - - const { sample } = sampleDataWithExtraParams; - - const theme = useTheme(); - const initialRef = useRef(true); - const lookerOptions = fos.useLookerOptions(true); - const [reset, setReset] = useState(false); - const selectedMediaField = useRecoilValue(fos.selectedMediaField(true)); - const setModalLooker = useSetRecoilState(fos.modalLooker); - - const createLooker = fos.useCreateLooker(true, false, { - ...lookerOptions, - }); - - const { setActiveLookerRef } = useModalContext(); - - const looker = React.useMemo( - () => createLooker.current(sampleDataWithExtraParams), - [reset, createLooker, selectedMediaField] - ) as AbstractLooker; - - useEffect(() => { - setModalLooker(looker); - }, [looker]); - - useEffect(() => { - if (looker) { - setActiveLookerRef(looker as fos.Lookers); - } - }, [looker]); - - useEffect(() => { - !initialRef.current && looker.updateOptions(lookerOptions); - }, [lookerOptions]); - - useEffect(() => { - !initialRef.current && looker.updateSample(sample); - }, [sample, colorScheme]); - - useEffect(() => { - return () => looker?.destroy(); - }, [looker]); - - const handleError = useErrorHandler(); +const ModalLookerNoTimeline = React.memo((props: LookerProps) => { + const { id, looker, ref } = useLooker(props); + const theme = useTheme(); + const setModalLooker = useSetRecoilState(fos.modalLooker); - const updateLookerOptions = useLookerOptionsUpdate(); - useEventHandler(looker, "options", (e) => updateLookerOptions(e.detail)); - useEventHandler(looker, "showOverlays", useShowOverlays()); - useEventHandler(looker, "reset", () => { - setReset((c) => !c); - }); + const { setActiveLookerRef } = useModalContext(); - const jsonPanel = fos.useJSONPanel(); - const helpPanel = fos.useHelpPanel(); + useEffect(() => { + setModalLooker(looker); + }, [looker, setModalLooker]); - useEventHandler(looker, "select", useOnSelectLabel()); - useEventHandler(looker, "error", (event) => handleError(event.detail)); - useEventHandler( - looker, - "panels", - async ({ detail: { showJSON, showHelp, SHORTCUTS } }) => { - if (showJSON) { - jsonPanel[showJSON](sample); - } - if (showHelp) { - if (showHelp == "close") { - helpPanel.close(); - } else { - helpPanel[showHelp](shortcutToHelpItems(SHORTCUTS)); - } - } - - updateLookerOptions({}, (updatedOptions) => - looker.updateOptions(updatedOptions) - ); - } - ); - - useEffect(() => { - initialRef.current = false; - }, []); - - useEffect(() => { - looker.attach(id); - }, [looker, id]); - - useEventHandler(looker, "clear", useClearSelectedLabels()); - - const hoveredSample = useRecoilValue(fos.hoveredSample); - - useEffect(() => { - const hoveredSampleId = hoveredSample?._id; - looker.updater((state) => ({ - ...state, - shouldHandleKeyEvents: hoveredSampleId === sample._id, - options: { - ...state.options, - }, - })); - }, [hoveredSample, sample, looker]); - - const ref = useRef(null); - useEffect(() => { - ref.current?.dispatchEvent( - new CustomEvent(`looker-attached`, { bubbles: true }) - ); - }, [ref]); - - return ( -
- ); - } -); + useEffect(() => { + if (looker) { + setActiveLookerRef(looker as fos.Lookers); + } + }, [looker, setActiveLookerRef]); + + return ( +
+ ); +}); export const ModalLooker = React.memo( ({ sample: propsSampleData }: LookerProps) => { @@ -197,21 +77,16 @@ export const ModalLooker = React.memo( const shouldRenderImavid = useRecoilValue( fos.shouldRenderImaVidLooker(true) ); + const video = useRecoilValue(fos.isVideoDataset); if (shouldRenderImavid) { return ; } + if (video) { + return ; + } + return ; } ); - -export function shortcutToHelpItems(SHORTCUTS) { - return Object.values( - Object.values(SHORTCUTS).reduce((acc, v) => { - acc[v.shortcut] = v; - - return acc; - }, {}) - ); -} diff --git a/app/packages/core/src/components/Modal/VideoLooker.tsx b/app/packages/core/src/components/Modal/VideoLooker.tsx new file mode 100644 index 0000000000..9c8f9b0cd2 --- /dev/null +++ b/app/packages/core/src/components/Modal/VideoLooker.tsx @@ -0,0 +1,80 @@ +import { useTheme } from "@fiftyone/components"; +import type { VideoLooker } from "@fiftyone/looker"; +import { getFrameNumber } from "@fiftyone/looker"; +import { + useCreateTimeline, + useDefaultTimelineNameImperative, + useTimeline, +} from "@fiftyone/playback"; +import * as fos from "@fiftyone/state"; +import React, { useEffect, useMemo, useState } from "react"; +import useLooker from "./use-looker"; + +interface VideoLookerReactProps { + sample: fos.ModalSample; +} + +export const VideoLookerReact = (props: VideoLookerReactProps) => { + const theme = useTheme(); + const { id, looker, sample } = useLooker(props); + const [totalFrames, setTotalFrames] = useState(); + const frameRate = useMemo(() => { + return sample.frameRate; + }, [sample]); + + useEffect(() => { + const load = () => { + const duration = looker.getVideo().duration; + setTotalFrames(getFrameNumber(duration, duration, frameRate)); + looker.removeEventListener("load", load); + }; + looker.addEventListener("load", load); + }, [frameRate, looker]); + + return ( + <> +
+ {totalFrames !== undefined && ( + + )} + + ); +}; + +const TimelineController = ({ + looker, + totalFrames, +}: { + looker: VideoLooker; + totalFrames: number; +}) => { + const { getName } = useDefaultTimelineNameImperative(); + const timelineName = React.useMemo(() => getName(), [getName]); + + useCreateTimeline({ + name: timelineName, + config: totalFrames + ? { + totalFrames, + loop: true, + } + : undefined, + optOutOfAnimation: true, + }); + + const { pause, play } = useTimeline(timelineName); + + fos.useEventHandler(looker, "pause", pause); + fos.useEventHandler(looker, "play", play); + + return null; +}; diff --git a/app/packages/core/src/components/Modal/hooks.ts b/app/packages/core/src/components/Modal/hooks.ts index 700955dd47..2cdb6d6310 100644 --- a/app/packages/core/src/components/Modal/hooks.ts +++ b/app/packages/core/src/components/Modal/hooks.ts @@ -19,6 +19,27 @@ export const useLookerHelpers = () => { }; }; +export const useLookerOptionsUpdate = () => { + return useRecoilCallback( + ({ snapshot, set }) => + async (update: object, updater?: (updated: {}) => void) => { + const currentOptions = await snapshot.getPromise( + fos.savedLookerOptions + ); + + const panels = await snapshot.getPromise(fos.lookerPanels); + const updated = { + ...currentOptions, + ...update, + showJSON: panels.json.isOpen, + showHelp: panels.help.isOpen, + }; + set(fos.savedLookerOptions, updated); + if (updater) updater(updated); + } + ); +}; + export const useInitializeImaVidSubscriptions = () => { const subscribeToImaVidStateChanges = useRecoilCallback( ({ set }) => diff --git a/app/packages/core/src/components/Modal/use-key-events.ts b/app/packages/core/src/components/Modal/use-key-events.ts new file mode 100644 index 0000000000..49a4ce313b --- /dev/null +++ b/app/packages/core/src/components/Modal/use-key-events.ts @@ -0,0 +1,40 @@ +import type { Lookers } from "@fiftyone/state"; +import { hoveredSample } from "@fiftyone/state"; +import type { MutableRefObject } from "react"; +import { useEffect, useRef } from "react"; +import { selector, useRecoilValue } from "recoil"; + +export const hoveredSampleId = selector({ + key: "hoveredSampleId", + get: ({ get }) => { + return get(hoveredSample)?._id; + }, +}); + +export default function ( + ref: MutableRefObject, + id: string, + looker: Lookers +) { + const hoveredId = useRecoilValue(hoveredSampleId); + const ready = useRef(false); + + useEffect(() => { + if (ref.current) { + // initial call should wait for load event + const update = () => { + looker.updateOptions({ + shouldHandleKeyEvents: id === hoveredId, + }); + ready.current = true; + + looker.removeEventListener("load", update); + }; + looker.addEventListener("load", update); + } else if (ready.current) { + looker.updateOptions({ + shouldHandleKeyEvents: id === hoveredId, + }); + } + }, [hoveredId, id, looker, ref]); +} diff --git a/app/packages/core/src/components/Modal/use-looker.ts b/app/packages/core/src/components/Modal/use-looker.ts new file mode 100644 index 0000000000..fa44065f95 --- /dev/null +++ b/app/packages/core/src/components/Modal/use-looker.ts @@ -0,0 +1,109 @@ +import * as fos from "@fiftyone/state"; +import React, { useEffect, useRef, useState } from "react"; +import { useErrorHandler } from "react-error-boundary"; +import { useRecoilValue } from "recoil"; +import { v4 as uuid } from "uuid"; +import { useClearSelectedLabels, useShowOverlays } from "./ModalLooker"; +import { useLookerOptionsUpdate } from "./hooks"; +import useKeyEvents from "./use-key-events"; +import { shortcutToHelpItems } from "./utils"; + +const CLOSE = "close"; + +function useLooker({ + sample, +}: { + sample: fos.ModalSample; +}) { + const [id] = useState(() => uuid()); + const initialRef = useRef(true); + const ref = useRef(null); + const [reset, setReset] = useState(false); + const lookerOptions = fos.useLookerOptions(true); + const createLooker = fos.useCreateLooker( + true, + false, + lookerOptions, + undefined, + true + ); + const selectedMediaField = useRecoilValue(fos.selectedMediaField(true)); + const colorScheme = useRecoilValue(fos.colorScheme); + const looker = React.useMemo(() => { + /** start refreshers */ + reset; + selectedMediaField; + /** end refreshers */ + + return createLooker.current(sample); + }, [createLooker, reset, sample, selectedMediaField]) as L; + const handleError = useErrorHandler(); + const updateLookerOptions = useLookerOptionsUpdate(); + + fos.useEventHandler(looker, "clear", useClearSelectedLabels()); + fos.useEventHandler(looker, "error", (event) => handleError(event.detail)); + fos.useEventHandler(looker, "options", (e) => updateLookerOptions(e.detail)); + fos.useEventHandler(looker, "reset", () => setReset((c) => !c)); + fos.useEventHandler(looker, "select", fos.useOnSelectLabel()); + fos.useEventHandler(looker, "showOverlays", useShowOverlays()); + + useEffect(() => { + !initialRef.current && looker.updateOptions(lookerOptions); + }, [looker, lookerOptions]); + + useEffect(() => { + /** start refreshers */ + colorScheme; + /** end refreshers */ + + !initialRef.current && looker.updateSample(sample); + }, [colorScheme, looker, sample]); + + useEffect(() => { + initialRef.current = false; + }, []); + + useEffect(() => { + ref.current?.dispatchEvent( + new CustomEvent("looker-attached", { bubbles: true }) + ); + }, []); + + useEffect(() => { + looker.attach(id); + }, [looker, id]); + + useEffect(() => { + return () => looker?.destroy(); + }, [looker]); + + const jsonPanel = fos.useJSONPanel(); + const helpPanel = fos.useHelpPanel(); + + fos.useEventHandler( + looker, + "panels", + async ({ detail: { showJSON, showHelp, SHORTCUTS } }) => { + if (showJSON) { + jsonPanel[showJSON](sample); + } + if (showHelp) { + if (showHelp === CLOSE) { + helpPanel.close(); + } else { + helpPanel[showHelp](shortcutToHelpItems(SHORTCUTS)); + } + } + + updateLookerOptions({}, (updatedOptions) => + looker.updateOptions(updatedOptions) + ); + } + ); + + useKeyEvents(initialRef, sample.sample._id, looker); + + return { id, looker, ref, sample, updateLookerOptions }; +} + +export default useLooker; diff --git a/app/packages/core/src/components/Modal/utils.ts b/app/packages/core/src/components/Modal/utils.ts new file mode 100644 index 0000000000..2461570f9a --- /dev/null +++ b/app/packages/core/src/components/Modal/utils.ts @@ -0,0 +1,7 @@ +export function shortcutToHelpItems(SHORTCUTS) { + const result = {}; + for (const k of SHORTCUTS) { + result[SHORTCUTS[k].shortcut] = SHORTCUTS[k]; + } + return Object.values(result); +} diff --git a/app/packages/looker/package.json b/app/packages/looker/package.json index b47bb84fd6..8ac5a0b945 100644 --- a/app/packages/looker/package.json +++ b/app/packages/looker/package.json @@ -39,5 +39,8 @@ "typescript": "^4.7.4", "typescript-plugin-css-modules": "^5.1.0", "vite": "^5.2.14" + }, + "peerDependencies": { + "jotai": "*" } } diff --git a/app/packages/looker/src/elements/base.ts b/app/packages/looker/src/elements/base.ts index e872a800c5..5ce470ef87 100644 --- a/app/packages/looker/src/elements/base.ts +++ b/app/packages/looker/src/elements/base.ts @@ -64,8 +64,7 @@ export abstract class BaseElement< for (const [eventType, handler] of Object.entries(this.getEvents(config))) { this.events[eventType] = (event) => handler({ event, update, dispatchEvent }); - this.element && - this.element.addEventListener(eventType, this.events[eventType]); + this.element?.addEventListener(eventType, this.events[eventType]); } } applyChildren(children: BaseElement[]) { diff --git a/app/packages/looker/src/elements/common/looker.ts b/app/packages/looker/src/elements/common/looker.ts index 8910d90cd0..b483b178f2 100644 --- a/app/packages/looker/src/elements/common/looker.ts +++ b/app/packages/looker/src/elements/common/looker.ts @@ -3,8 +3,10 @@ */ import { SELECTION_TEXT } from "../../constants"; -import { BaseState, Control, ControlEventKeyType } from "../../state"; -import { BaseElement, Events } from "../base"; +import type { BaseState, Control } from "../../state"; +import { ControlEventKeyType } from "../../state"; +import type { Events } from "../base"; +import { BaseElement } from "../base"; import { looker, lookerError, lookerHighlight } from "./looker.module.css"; @@ -24,7 +26,11 @@ export class LookerElement extends BaseElement< const e = event as KeyboardEvent; update((state) => { - const { SHORTCUTS, error, shouldHandleKeyEvents } = state; + const { + SHORTCUTS, + error, + options: { shouldHandleKeyEvents }, + } = state; if (!error && e.key in SHORTCUTS) { const matchedControl = SHORTCUTS[e.key] as Control; const enabled = @@ -43,7 +49,7 @@ export class LookerElement extends BaseElement< } const e = event as KeyboardEvent; - update(({ SHORTCUTS, error, shouldHandleKeyEvents }) => { + update(({ SHORTCUTS, error, options: { shouldHandleKeyEvents } }) => { if (!error && e.key in SHORTCUTS) { const matchedControl = SHORTCUTS[e.key] as Control; diff --git a/app/packages/looker/src/index.ts b/app/packages/looker/src/index.ts index 333140735d..889ec3ff73 100644 --- a/app/packages/looker/src/index.ts +++ b/app/packages/looker/src/index.ts @@ -3,7 +3,7 @@ */ export { createColorGenerator, getRGB } from "@fiftyone/utilities"; -export { freeVideos } from "./elements/util"; +export { freeVideos, getFrameNumber } from "./elements/util"; export * from "./lookers"; export type { PointInfo } from "./overlays"; export type { diff --git a/app/packages/looker/src/lookers/imavid/index.ts b/app/packages/looker/src/lookers/imavid/index.ts index ae3d60d1d2..cd302395e4 100644 --- a/app/packages/looker/src/lookers/imavid/index.ts +++ b/app/packages/looker/src/lookers/imavid/index.ts @@ -15,6 +15,8 @@ import { IMAVID_PLAYBACK_RATE_LOCAL_STORAGE_KEY, } from "./constants"; +export { BUFFERING_PAUSE_TIMEOUT } from "./constants"; + const DEFAULT_PAN = 0; const DEFAULT_SCALE = 1; const FIRST_FRAME = 1; diff --git a/app/packages/looker/src/lookers/video.ts b/app/packages/looker/src/lookers/video.ts index e07771fe7e..0307594409 100644 --- a/app/packages/looker/src/lookers/video.ts +++ b/app/packages/looker/src/lookers/video.ts @@ -20,7 +20,9 @@ import { } from "../state"; import { addToBuffers, createWorker, removeFromBuffers } from "../util"; +import { setFrameNumberAtom } from "@fiftyone/playback"; import { Schema } from "@fiftyone/utilities"; +import { getDefaultStore } from "jotai"; import { LRUCache } from "lru-cache"; import { CHUNK_SIZE, MAX_FRAME_CACHE_SIZE_BYTES } from "../constants"; import { getFrameNumber } from "../elements/util"; @@ -525,6 +527,13 @@ export class VideoLooker extends AbstractLooker { this.state.setZoom = false; } + if (this.state.config.enableTimeline) { + getDefaultStore().set(setFrameNumberAtom, { + name: `timeline-${this.state.config.sampleId}`, + newFrameNumber: this.state.frameNumber, + }); + } + return super.postProcess(); } @@ -546,6 +555,10 @@ export class VideoLooker extends AbstractLooker { this.setReader(); } + getVideo() { + return this.lookerElement.children[0].element as HTMLVideoElement; + } + private hasFrame(frameNumber: number) { return ( this.frames.has(frameNumber) && diff --git a/app/packages/looker/src/state.ts b/app/packages/looker/src/state.ts index 64dfe21278..84281f154a 100644 --- a/app/packages/looker/src/state.ts +++ b/app/packages/looker/src/state.ts @@ -175,6 +175,7 @@ interface BaseOptions { smoothMasks: boolean; zoomPad: number; selected: boolean; + shouldHandleKeyEvents?: boolean; inSelectionMode: boolean; timeZone: string; mimetype: string; @@ -218,6 +219,7 @@ export interface FrameConfig extends BaseConfig { export type ImageConfig = BaseConfig; export interface VideoConfig extends BaseConfig { + enableTimeline: boolean; frameRate: number; support?: [number, number]; } @@ -297,7 +299,6 @@ export interface BaseState { showOptions: boolean; config: BaseConfig; options: BaseOptions; - shouldHandleKeyEvents: boolean; scale: number; pan: Coordinates; panning: boolean; @@ -458,6 +459,7 @@ export const DEFAULT_BASE_OPTIONS: BaseOptions = { pointFilter: (path: string, point: Point) => true, attributeVisibility: {}, mediaFallback: false, + shouldHandleKeyEvents: true, }; export const DEFAULT_FRAME_OPTIONS: FrameOptions = { diff --git a/app/packages/operators/src/built-in-operators.ts b/app/packages/operators/src/built-in-operators.ts index 6566f954cd..ae210641fd 100644 --- a/app/packages/operators/src/built-in-operators.ts +++ b/app/packages/operators/src/built-in-operators.ts @@ -1313,6 +1313,42 @@ export class SetGroupSlice extends Operator { } } +export class DisableQueryPerformance extends Operator { + _builtIn = true; + get config(): OperatorConfig { + return new OperatorConfig({ + name: "disable_query_performance", + label: "Disable query performance", + }); + } + + useHooks() { + const { disable } = fos.useQueryPerformance(); + return { disable }; + } + async execute({ hooks }: ExecutionContext) { + hooks.disable(); + } +} + +export class EnableQueryPerformance extends Operator { + _builtIn = true; + get config(): OperatorConfig { + return new OperatorConfig({ + name: "enable_query_performance", + label: "Enable query performance", + }); + } + + useHooks() { + const { enable } = fos.useQueryPerformance(); + return { enable }; + } + async execute({ hooks }: ExecutionContext) { + hooks.enable(); + } +} + export function registerBuiltInOperators() { try { _registerBuiltInOperator(CopyViewAsJSON); @@ -1362,6 +1398,8 @@ export function registerBuiltInOperators() { _registerBuiltInOperator(SetGroupSlice); _registerBuiltInOperator(SetPlayheadState); _registerBuiltInOperator(SetFrameNumber); + _registerBuiltInOperator(DisableQueryPerformance); + _registerBuiltInOperator(EnableQueryPerformance); } catch (e) { console.error("Error registering built-in operators"); console.error(e); diff --git a/app/packages/operators/src/operators.ts b/app/packages/operators/src/operators.ts index f11aa3d880..d16f175fba 100644 --- a/app/packages/operators/src/operators.ts +++ b/app/packages/operators/src/operators.ts @@ -1,5 +1,5 @@ import { AnalyticsInfo, usingAnalytics } from "@fiftyone/analytics"; -import { getFetchFunction, isNullish, ServerError } from "@fiftyone/utilities"; +import { ServerError, getFetchFunction, isNullish } from "@fiftyone/utilities"; import { CallbackInterface } from "recoil"; import { QueueItemStatus } from "./constants"; import * as types from "./types"; @@ -91,6 +91,7 @@ export type RawContext = { scope: string; }; groupSlice: string; + queryPerformance?: boolean; }; export class ExecutionContext { @@ -136,6 +137,10 @@ export class ExecutionContext { public get groupSlice(): any { return this._currentContext.groupSlice; } + public get queryPerformance(): boolean { + return Boolean(this._currentContext.queryPerformance); + } + getCurrentPanelId(): string | null { return this.params.panel_id || this.currentPanel?.id || null; } @@ -706,6 +711,7 @@ export async function executeOperatorWithContext( view: currentContext.view, view_name: currentContext.viewName, group_slice: currentContext.groupSlice, + query_performance: currentContext.queryPerformance, } ); result = serverResult.result; diff --git a/app/packages/operators/src/state.ts b/app/packages/operators/src/state.ts index 97e675c5d1..cf0f08cd14 100644 --- a/app/packages/operators/src/state.ts +++ b/app/packages/operators/src/state.ts @@ -94,6 +94,7 @@ const globalContextSelector = selector({ const viewName = get(fos.viewName); const extendedSelection = get(fos.extendedSelection); const groupSlice = get(fos.groupSlice); + const queryPerformance = typeof get(fos.lightningThreshold) === "number"; return { datasetName, @@ -105,6 +106,7 @@ const globalContextSelector = selector({ viewName, extendedSelection, groupSlice, + queryPerformance, }; }, }); @@ -145,6 +147,7 @@ const useExecutionContext = (operatorName, hooks = {}) => { viewName, extendedSelection, groupSlice, + queryPerformance, } = curCtx; const [analyticsInfo] = useAnalyticsInfo(); const ctx = useMemo(() => { @@ -162,6 +165,7 @@ const useExecutionContext = (operatorName, hooks = {}) => { extendedSelection, analyticsInfo, groupSlice, + queryPerformance, }, hooks ); @@ -177,6 +181,7 @@ const useExecutionContext = (operatorName, hooks = {}) => { viewName, currentSample, groupSlice, + queryPerformance, ]); return ctx; diff --git a/app/packages/plugins/src/externalize.ts b/app/packages/plugins/src/externalize.ts index 923070fe30..f8a0fd2691 100644 --- a/app/packages/plugins/src/externalize.ts +++ b/app/packages/plugins/src/externalize.ts @@ -3,6 +3,7 @@ import * as foo from "@fiftyone/operators"; import * as fos from "@fiftyone/state"; import * as fou from "@fiftyone/utilities"; import * as fosp from "@fiftyone/spaces"; +import * as fop from "@fiftyone/plugins"; import * as mui from "@mui/material"; import React from "react"; import ReactDOM from "react-dom"; @@ -19,6 +20,7 @@ declare global { __fou__: typeof fou; __foo__: typeof foo; __fosp__: typeof fosp; + __fop__: typeof fop; __mui__: typeof mui; __styled__: typeof styled; } @@ -36,5 +38,6 @@ if (typeof window !== "undefined") { window.__foo__ = foo; window.__fosp__ = fosp; window.__mui__ = mui; + window.__fop__ = fop; window.__styled__ = styled; } diff --git a/app/packages/state/src/hooks/index.ts b/app/packages/state/src/hooks/index.ts index 8a1d691dfa..e1f31864cc 100644 --- a/app/packages/state/src/hooks/index.ts +++ b/app/packages/state/src/hooks/index.ts @@ -24,6 +24,7 @@ export { default as useLookerStore } from "./useLookerStore"; export { default as useNotification } from "./useNotification"; export * from "./useOnSelectLabel"; export { default as usePanel } from "./usePanel"; +export { default as useQueryPerformance } from "./useQueryPerformance"; export { default as useRefresh } from "./useRefresh"; export { default as useReset } from "./useReset"; export { default as useResetExtendedSelection } from "./useResetExtendedSelection"; diff --git a/app/packages/state/src/hooks/useCreateLooker.ts b/app/packages/state/src/hooks/useCreateLooker.ts index 06e04eb652..1fc1b748e3 100644 --- a/app/packages/state/src/hooks/useCreateLooker.ts +++ b/app/packages/state/src/hooks/useCreateLooker.ts @@ -9,7 +9,7 @@ import { } from "@fiftyone/looker"; import { ImaVidFramesController } from "@fiftyone/looker/src/lookers/imavid/controller"; import { ImaVidFramesControllerStore } from "@fiftyone/looker/src/lookers/imavid/store"; -import { BaseState, ImaVidConfig } from "@fiftyone/looker/src/state"; +import type { BaseState, ImaVidConfig } from "@fiftyone/looker/src/state"; import { EMBEDDED_DOCUMENT_FIELD, LIST_FIELD, @@ -36,7 +36,8 @@ export default >( isModal: boolean, thumbnail: boolean, options: Omit[0], "selected">, - highlight?: (sample: Sample) => boolean + highlight?: (sample: Sample) => boolean, + enableTimeline?: boolean ) => { const environment = useRelayEnvironment(); const selected = useRecoilValue(selectedSamples); @@ -112,6 +113,7 @@ export default >( } let config: ConstructorParameters[1] = { + enableTimeline, fieldSchema: { ...fieldSchema, frames: { @@ -132,6 +134,7 @@ export default >( mediaField, thumbnail, view, + shouldHandleKeyEvents: isModal, }; let sampleMediaFilePath = urls[mediaField]; diff --git a/app/packages/state/src/hooks/useQueryPerformance.ts b/app/packages/state/src/hooks/useQueryPerformance.ts new file mode 100644 index 0000000000..091dd85232 --- /dev/null +++ b/app/packages/state/src/hooks/useQueryPerformance.ts @@ -0,0 +1,41 @@ +import { useMemo } from "react"; +import { useRecoilCallback } from "recoil"; +import { + datasetSampleCount, + lightningThreshold, + lightningThresholdConfig, +} from "../recoil"; + +export default function () { + const disable = useRecoilCallback( + ({ set }) => + () => { + set(lightningThreshold, null); + }, + [] + ); + + const enable = useRecoilCallback( + ({ set, snapshot }) => + async (threshold?: number) => { + let setting = threshold; + + if (setting === undefined) { + setting = + (await snapshot.getPromise(lightningThresholdConfig)) ?? + (await snapshot.getPromise(datasetSampleCount)); + } + + set(lightningThreshold, setting); + }, + [] + ); + + return useMemo( + () => ({ + disable, + enable, + }), + [disable, enable] + ); +} diff --git a/app/packages/state/src/recoil/modal.ts b/app/packages/state/src/recoil/modal.ts index 2a79313a37..52baab87c2 100644 --- a/app/packages/state/src/recoil/modal.ts +++ b/app/packages/state/src/recoil/modal.ts @@ -1,13 +1,9 @@ -import { - AbstractLooker, - BaseState, - PointInfo, - type Sample, -} from "@fiftyone/looker"; +import { PointInfo, type Sample } from "@fiftyone/looker"; import { mainSample, mainSampleQuery } from "@fiftyone/relay"; import { atom, selector } from "recoil"; import { graphQLSelector } from "recoil-relay"; import { VariablesOf } from "relay-runtime"; +import type { Lookers } from "../hooks"; import { ComputeCoordinatesReturnType } from "../hooks/useTooltip"; import { ModalSelector, sessionAtom } from "../session"; import { ResponseFrom } from "../utils"; @@ -27,7 +23,7 @@ import { datasetName } from "./selectors"; import { mapSampleResponse } from "./utils"; import { view } from "./view"; -export const modalLooker = atom | null>({ +export const modalLooker = atom({ key: "modalLooker", default: null, dangerouslyAllowMutability: true, @@ -73,7 +69,7 @@ export const currentSampleId = selector({ ? get(pinned3DSample).id : get(nullableModalSampleId); - if (id && id.endsWith("-modal")) { + if (id?.endsWith("-modal")) { return id.replace("-modal", ""); } return id; diff --git a/app/packages/state/src/utils.ts b/app/packages/state/src/utils.ts index 1bb9e0c99a..4b61bc53f4 100644 --- a/app/packages/state/src/utils.ts +++ b/app/packages/state/src/utils.ts @@ -18,7 +18,7 @@ import { RecordSource, Store, } from "relay-runtime"; -import { State } from "./recoil"; +import { ModalSample, State } from "./recoil"; export const deferrer = (initialized: MutableRefObject) => @@ -105,6 +105,7 @@ export const getStandardizedUrls = ( urls: | readonly { readonly field: string; readonly url: string }[] | { [field: string]: string } + | ModalSample["urls"] ) => { if (!Array.isArray(urls)) { return urls; diff --git a/app/packages/utilities/src/schema.test.ts b/app/packages/utilities/src/schema.test.ts index f1d9624663..506c05ee97 100644 --- a/app/packages/utilities/src/schema.test.ts +++ b/app/packages/utilities/src/schema.test.ts @@ -19,7 +19,7 @@ const SCHEMA: schema.Schema = { "fiftyone.core.odm.embedded_document.DynamicEmbeddedDocument", ftype: "fiftyone.core.fields.EmbeddedDocumentField", info: {}, - name: "top", + name: "embedded", path: "embedded", subfield: null, fields: { @@ -30,7 +30,7 @@ const SCHEMA: schema.Schema = { ftype: "fiftyone.core.fields.EmbeddedDocumentField", info: {}, name: "field", - path: "field", + path: "embedded.field", subfield: null, }, }, @@ -42,7 +42,7 @@ const SCHEMA: schema.Schema = { "fiftyone.core.odm.embedded_document.DynamicEmbeddedDocument", ftype: "fiftyone.core.fields.EmbeddedDocumentField", info: {}, - name: "top", + name: "embeddedWithDbFields", path: "embeddedWithDbFields", subfield: null, fields: { @@ -54,7 +54,7 @@ const SCHEMA: schema.Schema = { ftype: "fiftyone.core.fields.EmbeddedDocumentField", info: {}, name: "sample_id", - path: "sample_id", + path: "embeddedWithDbFields.sample_id", subfield: null, }, }, @@ -62,7 +62,7 @@ const SCHEMA: schema.Schema = { }; describe("schema", () => { - describe("getCls ", () => { + describe("getCls", () => { it("should get top level cls", () => { expect(schema.getCls("top", SCHEMA)).toBe("TopLabel"); }); @@ -79,7 +79,7 @@ describe("schema", () => { }); }); - describe("getFieldInfo ", () => { + describe("getFieldInfo", () => { it("should get top level field info", () => { expect(schema.getFieldInfo("top", SCHEMA)).toEqual({ ...SCHEMA.top, @@ -89,7 +89,7 @@ describe("schema", () => { it("should get embedded field info", () => { expect(schema.getFieldInfo("embedded.field", SCHEMA)).toEqual({ - ...SCHEMA.embedded.fields.field, + ...SCHEMA.embedded.fields!.field, pathWithDbField: "", }); }); @@ -109,4 +109,69 @@ describe("schema", () => { expect(field?.pathWithDbField).toBe("embeddedWithDbFields._sample_id"); }); }); + + describe("getFieldsWithEmbeddedDocType", () => { + it("should get all fields with embeddedDocType at top level", () => { + expect( + schema.getFieldsWithEmbeddedDocType( + SCHEMA, + "fiftyone.core.labels.TopLabel" + ) + ).toEqual([SCHEMA.top]); + }); + + it("should get all fields with embeddedDocType in nested fields", () => { + expect( + schema.getFieldsWithEmbeddedDocType( + SCHEMA, + "fiftyone.core.labels.EmbeddedLabel" + ) + ).toEqual([ + SCHEMA.embedded.fields!.field, + SCHEMA.embeddedWithDbFields.fields!.sample_id, + ]); + }); + + it("should return empty array if embeddedDocType does not exist", () => { + expect( + schema.getFieldsWithEmbeddedDocType(SCHEMA, "nonexistentDocType") + ).toEqual([]); + }); + + it("should return empty array for empty schema", () => { + expect(schema.getFieldsWithEmbeddedDocType({}, "anyDocType")).toEqual([]); + }); + }); + + describe("doesSchemaContainEmbeddedDocType", () => { + it("should return true if embeddedDocType exists at top level", () => { + expect( + schema.doesSchemaContainEmbeddedDocType( + SCHEMA, + "fiftyone.core.labels.TopLabel" + ) + ).toBe(true); + }); + + it("should return true if embeddedDocType exists in nested fields", () => { + expect( + schema.doesSchemaContainEmbeddedDocType( + SCHEMA, + "fiftyone.core.labels.EmbeddedLabel" + ) + ).toBe(true); + }); + + it("should return false if embeddedDocType does not exist", () => { + expect( + schema.doesSchemaContainEmbeddedDocType(SCHEMA, "nonexistentDocType") + ).toBe(false); + }); + + it("should return false for empty schema", () => { + expect(schema.doesSchemaContainEmbeddedDocType({}, "anyDocType")).toBe( + false + ); + }); + }); }); diff --git a/app/packages/utilities/src/schema.ts b/app/packages/utilities/src/schema.ts index d65d2d9c7f..188ab1eb37 100644 --- a/app/packages/utilities/src/schema.ts +++ b/app/packages/utilities/src/schema.ts @@ -51,3 +51,43 @@ export function getCls(fieldPath: string, schema: Schema): string | undefined { return field.embeddedDocType.split(".").slice(-1)[0]; } + +export function getFieldsWithEmbeddedDocType( + schema: Schema, + embeddedDocType: string +): Field[] { + const result: Field[] = []; + + function recurse(schema: Schema) { + for (const field of Object.values(schema ?? {})) { + if (field.embeddedDocType === embeddedDocType) { + result.push(field); + } + if (field.fields) { + recurse(field.fields); + } + } + } + + recurse(schema); + return result; +} + +export function doesSchemaContainEmbeddedDocType( + schema: Schema, + embeddedDocType: string +): boolean { + function recurse(schema: Schema): boolean { + return Object.values(schema ?? {}).some((field) => { + if (field.embeddedDocType === embeddedDocType) { + return true; + } + if (field.fields) { + return recurse(field.fields); + } + return false; + }); + } + + return recurse(schema); +} diff --git a/app/yarn.lock b/app/yarn.lock index dc9917db4b..86e96c90f8 100644 --- a/app/yarn.lock +++ b/app/yarn.lock @@ -1885,6 +1885,8 @@ __metadata: typescript-plugin-css-modules: ^5.1.0 uuid: ^8.3.2 vite: ^5.2.14 + peerDependencies: + jotai: "*" languageName: unknown linkType: soft diff --git a/docs/source/integrations/coco.rst b/docs/source/integrations/coco.rst index 7ca3b0e890..2b910952ef 100644 --- a/docs/source/integrations/coco.rst +++ b/docs/source/integrations/coco.rst @@ -192,18 +192,14 @@ file containing COCO-formatted labels to work with: dataset = foz.load_zoo_dataset("quickstart") - # Classes list - classes = dataset.distinct("ground_truth.detections.label") - # The directory in which the dataset's images are stored IMAGES_DIR = os.path.dirname(dataset.first().filepath) # Export some labels in COCO format - dataset.take(5).export( + dataset.take(5, seed=51).export( dataset_type=fo.types.COCODetectionDataset, label_field="ground_truth", labels_path="/tmp/coco.json", - classes=classes, ) Now we have a ``/tmp/coco.json`` file on disk containing COCO labels @@ -220,7 +216,7 @@ corresponding to the images in ``IMAGES_DIR``: "licenses": [], "categories": [ { - "id": 0, + "id": 1, "name": "airplane", "supercategory": null }, @@ -229,9 +225,9 @@ corresponding to the images in ``IMAGES_DIR``: "images": [ { "id": 1, - "file_name": "001631.jpg", - "height": 612, - "width": 612, + "file_name": "003486.jpg", + "height": 427, + "width": 640, "license": null, "coco_url": null }, @@ -241,14 +237,14 @@ corresponding to the images in ``IMAGES_DIR``: { "id": 1, "image_id": 1, - "category_id": 9, + "category_id": 1, "bbox": [ - 92.14, - 220.04, - 519.86, - 61.89000000000001 + 34.34, + 147.46, + 492.69, + 192.36 ], - "area": 32174.135400000006, + "area": 94773.8484, "iscrowd": 0 }, ... @@ -271,8 +267,9 @@ dataset: include_id=True, ) - # Verify that the class list for our dataset was imported - print(coco_dataset.default_classes) # ['airplane', 'apple', ...] + # COCO categories are also imported + print(coco_dataset.info["categories"]) + # [{'id': 1, 'name': 'airplane', 'supercategory': None}, ...] print(coco_dataset) @@ -319,16 +316,16 @@ to add them to your dataset as follows: # # Mock COCO predictions, where: # - `image_id` corresponds to the `coco_id` field of `coco_dataset` - # - `category_id` corresponds to classes in `coco_dataset.default_classes` + # - `category_id` corresponds to `coco_dataset.info["categories"]` # predictions = [ - {"image_id": 1, "category_id": 18, "bbox": [258, 41, 348, 243], "score": 0.87}, - {"image_id": 2, "category_id": 11, "bbox": [61, 22, 504, 609], "score": 0.95}, + {"image_id": 1, "category_id": 2, "bbox": [258, 41, 348, 243], "score": 0.87}, + {"image_id": 2, "category_id": 4, "bbox": [61, 22, 504, 609], "score": 0.95}, ] + categories = coco_dataset.info["categories"] # Add COCO predictions to `predictions` field of dataset - classes = coco_dataset.default_classes - fouc.add_coco_labels(coco_dataset, "predictions", predictions, classes) + fouc.add_coco_labels(coco_dataset, "predictions", predictions, categories) # Verify that predictions were added to two images print(coco_dataset.count("predictions")) # 2 diff --git a/docs/source/user_guide/dataset_creation/datasets.rst b/docs/source/user_guide/dataset_creation/datasets.rst index c25550a191..ac4085bb3b 100644 --- a/docs/source/user_guide/dataset_creation/datasets.rst +++ b/docs/source/user_guide/dataset_creation/datasets.rst @@ -1499,9 +1499,8 @@ where `labels.json` is a JSON file in the following format: ... ], "categories": [ - ... { - "id": 2, + "id": 1, "name": "cat", "supercategory": "animal", "keypoints": ["nose", "head", ...], @@ -1524,7 +1523,7 @@ where `labels.json` is a JSON file in the following format: { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "bbox": [260, 177, 231, 199], "segmentation": [...], "keypoints": [224, 226, 2, ...], diff --git a/docs/source/user_guide/export_datasets.rst b/docs/source/user_guide/export_datasets.rst index 293672544a..810601036b 100644 --- a/docs/source/user_guide/export_datasets.rst +++ b/docs/source/user_guide/export_datasets.rst @@ -1646,9 +1646,8 @@ where `labels.json` is a JSON file in the following format: }, "licenses": [], "categories": [ - ... { - "id": 2, + "id": 1, "name": "cat", "supercategory": "animal" }, @@ -1669,7 +1668,7 @@ where `labels.json` is a JSON file in the following format: { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "bbox": [260, 177, 231, 199], "segmentation": [...], "score": 0.95, diff --git a/fiftyone/operators/executor.py b/fiftyone/operators/executor.py index 8b5da0d5ea..e0e6ac784c 100644 --- a/fiftyone/operators/executor.py +++ b/fiftyone/operators/executor.py @@ -713,6 +713,11 @@ def group_slice(self): """The current group slice of the view (if any).""" return self.request_params.get("group_slice", None) + @property + def query_performance(self): + """Whether query performance is enabled.""" + return self.request_params.get("query_performance", None) + def prompt( self, operator_uri, diff --git a/fiftyone/server/metadata.py b/fiftyone/server/metadata.py index 294d787782..6992447e4a 100644 --- a/fiftyone/server/metadata.py +++ b/fiftyone/server/metadata.py @@ -24,6 +24,7 @@ import fiftyone.core.labels as fol from fiftyone.core.collections import SampleCollection from fiftyone.utils.utils3d import OrthographicProjectionMetadata +from fiftyone.utils.rerun import RrdFile import fiftyone.core.media as fom @@ -33,6 +34,7 @@ fol.Heatmap: "map_path", fol.Segmentation: "mask_path", OrthographicProjectionMetadata: "filepath", + RrdFile: "filepath", } _FFPROBE_BINARY_PATH = shutil.which("ffprobe") diff --git a/fiftyone/utils/coco.py b/fiftyone/utils/coco.py index 76a4fd494b..b2c5a730d9 100644 --- a/fiftyone/utils/coco.py +++ b/fiftyone/utils/coco.py @@ -45,7 +45,7 @@ def add_coco_labels( sample_collection, label_field, labels_or_path, - classes, + categories, label_type="detections", coco_id_field=None, include_annotation_id=False, @@ -68,7 +68,7 @@ def add_coco_labels( { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "bbox": [260, 177, 231, 199], # optional @@ -88,7 +88,7 @@ def add_coco_labels( { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "bbox": [260, 177, 231, 199], "segmentation": [...], @@ -109,7 +109,7 @@ def add_coco_labels( { "id": 1, "image_id": 1, - "category_id": 2, + "category_id": 1, "keypoints": [224, 226, 2, ...], "num_keypoints": 10, @@ -129,8 +129,14 @@ def add_coco_labels( will be created if necessary labels_or_path: a list of COCO annotations or the path to a JSON file containing such data on disk - classes: the list of class label strings or a dict mapping class IDs to - class labels + categories: can be any of the following: + + - a list of category dicts in the format of + :meth:`parse_coco_categories` specifying the classes and their + category IDs + - a dict mapping class IDs to class labels + - a list of class labels whose 1-based ordering is assumed to + correspond to the category IDs in the provided COCO labels label_type ("detections"): the type of labels to load. Supported values are ``("detections", "segmentations", "keypoints")`` coco_id_field (None): this parameter determines how to map the @@ -195,10 +201,14 @@ class labels view.compute_metadata() widths, heights = view.values(["metadata.width", "metadata.height"]) - if isinstance(classes, dict): - classes_map = classes + if isinstance(categories, dict): + classes_map = categories + elif not categories: + classes_map = {} + elif isinstance(categories[0], dict): + classes_map = {c["id"]: c["name"] for c in categories} else: - classes_map = {i: label for i, label in enumerate(classes)} + classes_map = {i: label for i, label in enumerate(categories, 1)} labels = [] for _coco_objects, width, height in zip(coco_objects, widths, heights): @@ -563,15 +573,11 @@ def setup(self): self.labels_path, extra_attrs=self.extra_attrs ) - classes = None if classes_map is not None: - classes = _to_classes(classes_map) - - if classes is not None: - info["classes"] = classes + info["classes"] = _to_classes(classes_map) image_ids = _get_matching_image_ids( - classes, + classes_map, images, annotations, image_ids=self.image_ids, @@ -907,12 +913,11 @@ def export_sample(self, image_or_path, label, metadata=None): def close(self, *args): if self._dynamic_classes: - classes = sorted(self._classes) - labels_map_rev = _to_labels_map_rev(classes) + labels_map_rev = _to_labels_map_rev(sorted(self._classes)) for anno in self._annotations: anno["category_id"] = labels_map_rev[anno["category_id"]] - else: - classes = self.classes + elif self.categories is None: + labels_map_rev = _to_labels_map_rev(self.classes) _info = self.info or {} _date_created = datetime.now().replace(microsecond=0).isoformat() @@ -933,10 +938,10 @@ def close(self, *args): categories = [ { "id": i, - "name": l, + "name": c, "supercategory": None, } - for i, l in enumerate(classes) + for c, i in sorted(labels_map_rev.items(), key=lambda t: t[1]) ] labels = { @@ -1681,7 +1686,7 @@ def download_coco_dataset_split( if classes is not None: # Filter by specified classes all_ids, any_ids = _get_images_with_classes( - image_ids, annotations, classes, all_classes + image_ids, annotations, classes, all_classes_map ) else: all_ids = image_ids @@ -1846,7 +1851,7 @@ def _parse_include_license(include_license): def _get_matching_image_ids( - all_classes, + classes_map, images, annotations, image_ids=None, @@ -1862,7 +1867,7 @@ def _get_matching_image_ids( if classes is not None: all_ids, any_ids = _get_images_with_classes( - image_ids, annotations, classes, all_classes + image_ids, annotations, classes, classes_map ) else: all_ids = image_ids @@ -1930,7 +1935,7 @@ def _do_download(args): def _get_images_with_classes( - image_ids, annotations, target_classes, all_classes + image_ids, annotations, target_classes, classes_map ): if annotations is None: logger.warning("Dataset is unlabeled; ignoring classes requirement") @@ -1939,11 +1944,12 @@ def _get_images_with_classes( if etau.is_str(target_classes): target_classes = [target_classes] - bad_classes = [c for c in target_classes if c not in all_classes] + labels_map_rev = {c: i for i, c in classes_map.items()} + + bad_classes = [c for c in target_classes if c not in labels_map_rev] if bad_classes: raise ValueError("Unsupported classes: %s" % bad_classes) - labels_map_rev = _to_labels_map_rev(all_classes) class_ids = {labels_map_rev[c] for c in target_classes} all_ids = [] @@ -2029,7 +2035,7 @@ def _load_image_ids_json(json_path): def _to_labels_map_rev(classes): - return {c: i for i, c in enumerate(classes)} + return {c: i for i, c in enumerate(classes, 1)} def _to_classes(classes_map): diff --git a/fiftyone/utils/rerun.py b/fiftyone/utils/rerun.py new file mode 100644 index 0000000000..7b72e5601d --- /dev/null +++ b/fiftyone/utils/rerun.py @@ -0,0 +1,25 @@ +""" +Utilities for working with `Rerun `_. + +| Copyright 2017-2024, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import fiftyone.core.fields as fof +import fiftyone.core.labels as fol +from fiftyone.core.odm import DynamicEmbeddedDocument + + +class RrdFile(DynamicEmbeddedDocument, fol._HasMedia): + """Class for storing a rerun data (rrd) file and its associated metadata. + + Args: + filepath (None): the path to the rrd file + version (None): the version of the rrd file + """ + + _MEDIA_FIELD = "filepath" + + filepath = fof.StringField() + version = fof.StringField() diff --git a/tests/unittests/import_export_tests.py b/tests/unittests/import_export_tests.py index 54798733f5..896429d8a7 100644 --- a/tests/unittests/import_export_tests.py +++ b/tests/unittests/import_export_tests.py @@ -1317,6 +1317,65 @@ def test_coco_detection_dataset(self): {c["id"] for c in categories2}, ) + # Alphabetized 1-based categories by default + + export_dir = self._new_dir() + + dataset.export( + export_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + label_types="detections", + label_field="predictions", + ) + categories2 = dataset2.info["categories"] + + self.assertListEqual([c["id"] for c in categories2], [1, 2]) + self.assertListEqual([c["name"] for c in categories2], ["cat", "dog"]) + + # Only load matching classes + + export_dir = self._new_dir() + + dataset.export( + export_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + ) + + dataset2 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + label_types="detections", + label_field="predictions", + classes="cat", + only_matching=False, + ) + + self.assertEqual(len(dataset2), 2) + self.assertListEqual( + dataset2.distinct("predictions.detections.label"), + ["cat", "dog"], + ) + + dataset3 = fo.Dataset.from_dir( + dataset_dir=export_dir, + dataset_type=fo.types.COCODetectionDataset, + label_types="detections", + label_field="predictions", + classes="cat", + only_matching=True, + ) + + self.assertEqual(len(dataset3), 2) + self.assertListEqual( + dataset3.distinct("predictions.detections.label"), + ["cat"], + ) + @drop_datasets def test_voc_detection_dataset(self): dataset = self._make_dataset() @@ -1758,16 +1817,19 @@ def test_add_yolo_labels(self): @drop_datasets def test_add_coco_labels(self): dataset = self._make_dataset() + classes = dataset.distinct("predictions.detections.label") + categories = [{"id": i, "name": l} for i, l in enumerate(classes, 1)] export_dir = self._new_dir() dataset.export( export_dir=export_dir, dataset_type=fo.types.COCODetectionDataset, + categories=categories, ) coco_labels_path = os.path.join(export_dir, "labels.json") - fouc.add_coco_labels(dataset, "coco", coco_labels_path, classes) + fouc.add_coco_labels(dataset, "coco", coco_labels_path, categories) self.assertEqual( dataset.count_values("predictions.detections.label"), dataset.count_values("coco.detections.label"),