diff --git a/packages/client/src/components/TagTraining.component.tsx b/packages/client/src/components/TagTraining.component.tsx index a2a46d0e..3c877d99 100644 --- a/packages/client/src/components/TagTraining.component.tsx +++ b/packages/client/src/components/TagTraining.component.tsx @@ -1,9 +1,10 @@ import { DatasetsView } from './DatasetsView.component'; import { useState, useEffect, SetStateAction, Dispatch } from 'react'; -import { useGetDatasetsQuery } from '../graphql/dataset/dataset'; +import { useGetDatasetsByProjectLazyQuery } from '../graphql/dataset/dataset'; import { Dataset, Entry } from '../graphql/graphql'; import { GridColDef } from '@mui/x-data-grid'; import { Switch } from '@mui/material'; +import { useProject } from '../context/Project.context'; export interface TagTrainingComponentProps { setTrainingSet: Dispatch>; @@ -12,7 +13,14 @@ export interface TagTrainingComponentProps { export const TagTrainingComponent: React.FC = (props) => { const [datasets, setDatasets] = useState([]); - const getDatasetsResults = useGetDatasetsQuery(); + const { project } = useProject(); + const [getDatasetsQuery, getDatasetsResults] = useGetDatasetsByProjectLazyQuery(); + + useEffect(() => { + if (project) { + getDatasetsQuery({ variables: { project: project._id } }); + } + }, [project]); const trainingSet: Set = new Set(); const fullSet: Set = new Set(); @@ -67,7 +75,7 @@ export const TagTrainingComponent: React.FC = (props) // accessible by the current project useEffect(() => { if (getDatasetsResults.data) { - setDatasets(getDatasetsResults.data.getDatasets); + setDatasets(getDatasetsResults.data.getDatasetsByProject); } }, [getDatasetsResults.data]);