From 1a7cc3855d389e550edebdd7f91be51f443a9878 Mon Sep 17 00:00:00 2001 From: cbolles Date: Fri, 1 Mar 2024 14:29:20 -0500 Subject: [PATCH 1/5] Begin capturing training set logic --- packages/server/src/tag/models/training-set.ts | 14 ++++++++++++++ .../src/tag/resolvers/training-set.resolver.ts | 6 ++++++ .../src/tag/services/training-set.service.ts | 4 ++++ packages/server/src/tag/tag.module.ts | 10 ++++++++-- 4 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 packages/server/src/tag/models/training-set.ts create mode 100644 packages/server/src/tag/resolvers/training-set.resolver.ts create mode 100644 packages/server/src/tag/services/training-set.service.ts diff --git a/packages/server/src/tag/models/training-set.ts b/packages/server/src/tag/models/training-set.ts new file mode 100644 index 00000000..b2cf145c --- /dev/null +++ b/packages/server/src/tag/models/training-set.ts @@ -0,0 +1,14 @@ +import { Schema, Prop, SchemaFactory } from '@nestjs/mongoose'; +import { Document } from 'mongoose'; + +@Schema() +export class TrainingSet { + @Prop() + study: string; + + @Prop() + entries: string[]; +} + +export type TrainingSetDocument = TrainingSet & Document; +export const TrainingSetSchema = SchemaFactory.createForClass(TrainingSet); diff --git a/packages/server/src/tag/resolvers/training-set.resolver.ts b/packages/server/src/tag/resolvers/training-set.resolver.ts new file mode 100644 index 00000000..c984f3ea --- /dev/null +++ b/packages/server/src/tag/resolvers/training-set.resolver.ts @@ -0,0 +1,6 @@ +import { Resolver, Mutation } from '@nestjs/graphql'; + +@Resolver() +export class TrainingSetResolver { + +} diff --git a/packages/server/src/tag/services/training-set.service.ts b/packages/server/src/tag/services/training-set.service.ts new file mode 100644 index 00000000..bef36e4d --- /dev/null +++ b/packages/server/src/tag/services/training-set.service.ts @@ -0,0 +1,4 @@ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class TrainingSetService {} diff --git a/packages/server/src/tag/tag.module.ts b/packages/server/src/tag/tag.module.ts index 8885edaa..0d5e81a7 100644 --- a/packages/server/src/tag/tag.module.ts +++ b/packages/server/src/tag/tag.module.ts @@ -16,12 +16,16 @@ import { TagTransformer } from './services/tag-transformer.service'; import { FieldTransformerFactory } from './transformers/field-transformer-factory'; import { VideoFieldTransformer } from './transformers/video-field-transformer'; import { DatasetModule } from '../dataset/dataset.module'; +import { TrainingSet, TrainingSetSchema } from './models/training-set'; +import { TrainingSetResolver } from './resolvers/training-set.resolver'; +import { TrainingSetService } from './services/training-set.service'; @Module({ imports: [ MongooseModule.forFeature([ { name: Tag.name, schema: TagSchema }, - { name: VideoField.name, schema: VideoFieldSchema } + { name: VideoField.name, schema: VideoFieldSchema }, + { name: TrainingSet.name, schema: TrainingSetSchema } ]), StudyModule, EntryModule, @@ -38,7 +42,9 @@ import { DatasetModule } from '../dataset/dataset.module'; VideoFieldResolver, TagTransformer, FieldTransformerFactory, - VideoFieldTransformer + VideoFieldTransformer, + TrainingSetResolver, + TrainingSetService ] }) export class TagModule {} From a765041d57d6bdfd9d3d2814e713cd927f27d464 Mon Sep 17 00:00:00 2001 From: cbolles Date: Fri, 1 Mar 2024 14:46:57 -0500 Subject: [PATCH 2/5] Add in resolver for creating new entries --- packages/server/src/tag/models/tag.model.ts | 4 +++ .../tag/resolvers/training-set.resolver.ts | 31 ++++++++++++++++++- .../src/tag/services/training-set.service.ts | 16 +++++++++- 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/packages/server/src/tag/models/tag.model.ts b/packages/server/src/tag/models/tag.model.ts index d5da7bfb..78408a0c 100644 --- a/packages/server/src/tag/models/tag.model.ts +++ b/packages/server/src/tag/models/tag.model.ts @@ -41,6 +41,10 @@ export class Tag { @Prop() @Field({ description: 'If the tag is enabled as part of the study, way to disable certain tags' }) enabled: boolean; + + @Prop() + @Field({ description: 'If the tag is part of a training' }) + training: boolean; } export type TagDocument = Tag & Document; diff --git a/packages/server/src/tag/resolvers/training-set.resolver.ts b/packages/server/src/tag/resolvers/training-set.resolver.ts index c984f3ea..9f3fc534 100644 --- a/packages/server/src/tag/resolvers/training-set.resolver.ts +++ b/packages/server/src/tag/resolvers/training-set.resolver.ts @@ -1,6 +1,35 @@ -import { Resolver, Mutation } from '@nestjs/graphql'; +import { Resolver, Mutation, Args, ID } from '@nestjs/graphql'; +import { Entry } from '../../entry/models/entry.model'; +import { EntriesPipe } from '../../entry/pipes/entry.pipe'; +import { TokenPayload } from '../../jwt/token.dto'; +import { TokenContext } from '../../jwt/token.context'; +import { StudyPipe } from '../../study/pipes/study.pipe'; +import { Study } from '../../study/study.model'; +import { Tag } from '../models/tag.model' +import { TrainingSetService } from '../services/training-set.service'; +import { Inject, UnauthorizedException } from '@nestjs/common'; +import { CASBIN_PROVIDER } from 'src/permission/casbin.provider'; +import * as casbin from 'casbin'; +import { StudyPermissions } from 'src/permission/permissions/study'; @Resolver() export class TrainingSetResolver { + constructor( + private readonly trainingService: TrainingSetService, + @Inject(CASBIN_PROVIDER) private readonly enforcer: casbin.Enforcer + ) {} + + @Mutation(() => Boolean) + async createTrainingSet( + @Args('study', { type: () => ID }, StudyPipe) study: Study, + @Args('entries', { type: () => [ID] }, EntriesPipe) entries: Entry[], + @TokenContext() user: TokenPayload + ): Promise { + if (!(await this.enforcer.enforce(user.user_id, StudyPermissions.CREATE, study._id.toString()))) { + throw new UnauthorizedException('User cannot create a training set for this study'); + } + await this.trainingService.create(study, entries); + return true; + } } diff --git a/packages/server/src/tag/services/training-set.service.ts b/packages/server/src/tag/services/training-set.service.ts index bef36e4d..9f35932f 100644 --- a/packages/server/src/tag/services/training-set.service.ts +++ b/packages/server/src/tag/services/training-set.service.ts @@ -1,4 +1,18 @@ import { Injectable } from '@nestjs/common'; +import { InjectModel } from '@nestjs/mongoose'; +import { Model } from 'mongoose'; +import { Entry } from '../../entry/models/entry.model'; +import { Study } from '../../study/study.model'; +import { TrainingSet } from '../models/training-set'; @Injectable() -export class TrainingSetService {} +export class TrainingSetService { + constructor(@InjectModel(TrainingSet.name) private readonly trainingSetModel: Model) {} + + async create(study: Study, entries: Entry[]): Promise { + return this.trainingSetModel.create({ + study: study._id, + entries: entries.map((entry) => entry._id) + }) + } +} From ee0b0410c24dfa180adf5bad7036afe88f8aea7b Mon Sep 17 00:00:00 2001 From: cbolles Date: Fri, 1 Mar 2024 15:07:38 -0500 Subject: [PATCH 3/5] Creation of training set working --- .../src/components/TagTraining.component.tsx | 2 + packages/client/src/graphql/graphql.ts | 9 +++++ packages/client/src/graphql/tag/tag.graphql | 5 +++ packages/client/src/graphql/tag/tag.ts | 40 +++++++++++++++++++ .../client/src/pages/studies/NewStudy.tsx | 19 ++++++--- .../tag/resolvers/training-set.resolver.ts | 5 ++- 6 files changed, 73 insertions(+), 7 deletions(-) diff --git a/packages/client/src/components/TagTraining.component.tsx b/packages/client/src/components/TagTraining.component.tsx index 3c877d99..d6aeef2b 100644 --- a/packages/client/src/components/TagTraining.component.tsx +++ b/packages/client/src/components/TagTraining.component.tsx @@ -36,6 +36,7 @@ export const TagTrainingComponent: React.FC = (props) onLoad={(_entry) => {}} add={(entry) => { trainingSet.add(entry._id); + console.log(trainingSet); props.setTrainingSet(Array.from(trainingSet)); }} remove={(entry) => { @@ -59,6 +60,7 @@ export const TagTrainingComponent: React.FC = (props) }} add={(entry) => { fullSet.add(entry._id); + console.log(fullSet); props.setTaggingSet(Array.from(fullSet)); }} remove={(entry) => { diff --git a/packages/client/src/graphql/graphql.ts b/packages/client/src/graphql/graphql.ts index e8007aa6..486a8f49 100644 --- a/packages/client/src/graphql/graphql.ts +++ b/packages/client/src/graphql/graphql.ts @@ -112,6 +112,7 @@ export type Mutation = { createOrganization: Organization; createStudy: Study; createTags: Array; + createTrainingSet: Scalars['Boolean']['output']; createUploadSession: UploadSession; deleteEntry: Scalars['Boolean']['output']; deleteProject: Scalars['Boolean']['output']; @@ -193,6 +194,12 @@ export type MutationCreateTagsArgs = { }; +export type MutationCreateTrainingSetArgs = { + entries: Array; + study: Scalars['ID']['input']; +}; + + export type MutationCreateUploadSessionArgs = { dataset: Scalars['ID']['input']; }; @@ -507,6 +514,8 @@ export type Tag = { /** Way to rank tags based on order to be tagged */ order: Scalars['Float']['output']; study: Study; + /** If the tag is part of a training */ + training: Scalars['Boolean']['output']; /** The user assigned to the tag */ user?: Maybe; }; diff --git a/packages/client/src/graphql/tag/tag.graphql b/packages/client/src/graphql/tag/tag.graphql index d30a7514..2efa1006 100644 --- a/packages/client/src/graphql/tag/tag.graphql +++ b/packages/client/src/graphql/tag/tag.graphql @@ -4,6 +4,10 @@ mutation createTags($study: ID!, $entries: [ID!]!) { } } +mutation createTrainingSet($study: ID!, $entries: [ID!]!) { + createTrainingSet(study: $study, entries: $entries) +} + mutation setEntryEnabled($study: ID!, $entry: ID!, $enabled: Boolean!) { setEntryEnabled(study: $study, entry: $entry, enabled: $enabled) } @@ -60,3 +64,4 @@ query getTags($study: ID!) { complete } } + diff --git a/packages/client/src/graphql/tag/tag.ts b/packages/client/src/graphql/tag/tag.ts index d3c4b3b5..6417d986 100644 --- a/packages/client/src/graphql/tag/tag.ts +++ b/packages/client/src/graphql/tag/tag.ts @@ -13,6 +13,14 @@ export type CreateTagsMutationVariables = Types.Exact<{ export type CreateTagsMutation = { __typename?: 'Mutation', createTags: Array<{ __typename?: 'Tag', _id: string }> }; +export type CreateTrainingSetMutationVariables = Types.Exact<{ + study: Types.Scalars['ID']['input']; + entries: Array | Types.Scalars['ID']['input']; +}>; + + +export type CreateTrainingSetMutation = { __typename?: 'Mutation', createTrainingSet: boolean }; + export type SetEntryEnabledMutationVariables = Types.Exact<{ study: Types.Scalars['ID']['input']; entry: Types.Scalars['ID']['input']; @@ -96,6 +104,38 @@ export function useCreateTagsMutation(baseOptions?: Apollo.MutationHookOptions; export type CreateTagsMutationResult = Apollo.MutationResult; export type CreateTagsMutationOptions = Apollo.BaseMutationOptions; +export const CreateTrainingSetDocument = gql` + mutation createTrainingSet($study: ID!, $entries: [ID!]!) { + createTrainingSet(study: $study, entries: $entries) +} + `; +export type CreateTrainingSetMutationFn = Apollo.MutationFunction; + +/** + * __useCreateTrainingSetMutation__ + * + * To run a mutation, you first call `useCreateTrainingSetMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useCreateTrainingSetMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [createTrainingSetMutation, { data, loading, error }] = useCreateTrainingSetMutation({ + * variables: { + * study: // value for 'study' + * entries: // value for 'entries' + * }, + * }); + */ +export function useCreateTrainingSetMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(CreateTrainingSetDocument, options); + } +export type CreateTrainingSetMutationHookResult = ReturnType; +export type CreateTrainingSetMutationResult = Apollo.MutationResult; +export type CreateTrainingSetMutationOptions = Apollo.BaseMutationOptions; export const SetEntryEnabledDocument = gql` mutation setEntryEnabled($study: ID!, $entry: ID!, $enabled: Boolean!) { setEntryEnabled(study: $study, entry: $entry, enabled: $enabled) diff --git a/packages/client/src/pages/studies/NewStudy.tsx b/packages/client/src/pages/studies/NewStudy.tsx index 7459b73a..e1f29dd2 100644 --- a/packages/client/src/pages/studies/NewStudy.tsx +++ b/packages/client/src/pages/studies/NewStudy.tsx @@ -5,11 +5,11 @@ import { TagTrainingComponent } from '../../components/TagTraining.component'; import { useState, useEffect } from 'react'; import { StudyCreate, TagSchema } from '../../graphql/graphql'; import { PartialStudyCreate } from '../../types/study'; -import { CreateStudyDocument } from '../../graphql/study/study'; +import { CreateStudyDocument, CreateStudyMutation, CreateStudyMutationVariables } from '../../graphql/study/study'; import { useProject } from '../../context/Project.context'; import { useStudy } from '../../context/Study.context'; import { useApolloClient } from '@apollo/client'; -import { CreateTagsDocument } from '../../graphql/tag/tag'; +import { CreateTagsDocument, CreateTrainingSetDocument, CreateTagsMutationVariables, CreateTagsMutation, CreateTrainingSetMutation, CreateTrainingSetMutationVariables } from '../../graphql/tag/tag'; import { useTranslation } from 'react-i18next'; import { TagFieldFragmentSchema, TagField } from '../../components/tagbuilder/TagProvider'; @@ -20,7 +20,7 @@ export const NewStudy: React.FC = () => { const [tagSchema, setTagSchema] = useState(null); const { project } = useProject(); const { updateStudies } = useStudy(); - const [_trainingSet, setTrainingSet] = useState([]); + const [trainingSet, setTrainingSet] = useState([]); const [taggingSet, setTaggingSet] = useState([]); const apolloClient = useApolloClient(); // The different fields that make up the tag schema @@ -70,7 +70,7 @@ export const NewStudy: React.FC = () => { }; // Make the new study - const result = await apolloClient.mutate({ + const result = await apolloClient.mutate({ mutation: CreateStudyDocument, variables: { study: study } }); @@ -81,10 +81,19 @@ export const NewStudy: React.FC = () => { } // Create the corresponding tags - await apolloClient.mutate({ + await apolloClient.mutate({ mutation: CreateTagsDocument, variables: { study: result.data.createStudy._id, entries: taggingSet } }); + + // Create the training set + await apolloClient.mutate({ + mutation: CreateTrainingSetDocument, + variables: { + study: result.data.createStudy._id, + entries: trainingSet + } + }); updateStudies(); } setActiveStep((prevActiveStep: number) => prevActiveStep + 1); diff --git a/packages/server/src/tag/resolvers/training-set.resolver.ts b/packages/server/src/tag/resolvers/training-set.resolver.ts index 9f3fc534..66a59d1c 100644 --- a/packages/server/src/tag/resolvers/training-set.resolver.ts +++ b/packages/server/src/tag/resolvers/training-set.resolver.ts @@ -5,13 +5,14 @@ import { TokenPayload } from '../../jwt/token.dto'; import { TokenContext } from '../../jwt/token.context'; import { StudyPipe } from '../../study/pipes/study.pipe'; import { Study } from '../../study/study.model'; -import { Tag } from '../models/tag.model' import { TrainingSetService } from '../services/training-set.service'; -import { Inject, UnauthorizedException } from '@nestjs/common'; +import { Inject, UnauthorizedException, UseGuards } from '@nestjs/common'; import { CASBIN_PROVIDER } from 'src/permission/casbin.provider'; import * as casbin from 'casbin'; import { StudyPermissions } from 'src/permission/permissions/study'; +import { JwtAuthGuard } from '../../jwt/jwt.guard'; +@UseGuards(JwtAuthGuard) @Resolver() export class TrainingSetResolver { From 05d7a5c645c7a310824753325b307b8fa7dc2835 Mon Sep 17 00:00:00 2001 From: cbolles Date: Fri, 1 Mar 2024 15:23:16 -0500 Subject: [PATCH 4/5] Improved selection of entries --- .../src/components/TagTraining.component.tsx | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/packages/client/src/components/TagTraining.component.tsx b/packages/client/src/components/TagTraining.component.tsx index d6aeef2b..b50c1a4f 100644 --- a/packages/client/src/components/TagTraining.component.tsx +++ b/packages/client/src/components/TagTraining.component.tsx @@ -15,6 +15,8 @@ export const TagTrainingComponent: React.FC = (props) const [datasets, setDatasets] = useState([]); const { project } = useProject(); const [getDatasetsQuery, getDatasetsResults] = useGetDatasetsByProjectLazyQuery(); + const [trainingSet, setTrainingSet] = useState([]); + const [taggingSet, setTaggingSet] = useState([]); useEffect(() => { if (project) { @@ -22,9 +24,6 @@ export const TagTrainingComponent: React.FC = (props) } }, [project]); - const trainingSet: Set = new Set(); - const fullSet: Set = new Set(); - const additionalColumns: GridColDef[] = [ { field: 'training', @@ -35,13 +34,10 @@ export const TagTrainingComponent: React.FC = (props) startingValue={false} onLoad={(_entry) => {}} add={(entry) => { - trainingSet.add(entry._id); - console.log(trainingSet); - props.setTrainingSet(Array.from(trainingSet)); + setTrainingSet([...trainingSet, entry._id]); }} remove={(entry) => { - trainingSet.delete(entry._id); - props.setTrainingSet(Array.from(trainingSet)); + setTrainingSet(trainingSet.filter((entryID) => entryID != entry._id)); }} entry={params.row} /> @@ -53,19 +49,13 @@ export const TagTrainingComponent: React.FC = (props) width: 200, renderCell: (params) => ( { - fullSet.add(entry._id); - props.setTaggingSet(Array.from(fullSet)); - }} + startingValue={false} + onLoad={(_entry) => {}} add={(entry) => { - fullSet.add(entry._id); - console.log(fullSet); - props.setTaggingSet(Array.from(fullSet)); + setTaggingSet([...taggingSet, entry._id]); }} remove={(entry) => { - fullSet.delete(entry._id); - props.setTaggingSet(Array.from(fullSet)); + setTaggingSet(taggingSet.filter((entryID) => entryID != entry._id)); }} entry={params.row} /> @@ -73,6 +63,16 @@ export const TagTrainingComponent: React.FC = (props) } ]; + useEffect(() => { + const entries = Array.from(new Set(taggingSet)); + props.setTaggingSet(entries); + }, [taggingSet]); + + useEffect(() => { + const entries = Array.from(new Set(trainingSet)); + props.setTrainingSet(entries); + }, [trainingSet]) + // TODO: In the future, the datasets retrieved should only be datasets // accessible by the current project useEffect(() => { From 1ef99b9150341399dd013599be1cd09f98fdb6b1 Mon Sep 17 00:00:00 2001 From: cbolles Date: Fri, 1 Mar 2024 15:24:19 -0500 Subject: [PATCH 5/5] Fix formatting --- packages/client/src/components/TagTraining.component.tsx | 2 +- packages/client/src/pages/studies/NewStudy.tsx | 9 ++++++++- .../server/src/tag/resolvers/training-set.resolver.ts | 1 - packages/server/src/tag/services/training-set.service.ts | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/packages/client/src/components/TagTraining.component.tsx b/packages/client/src/components/TagTraining.component.tsx index b50c1a4f..a8f0487e 100644 --- a/packages/client/src/components/TagTraining.component.tsx +++ b/packages/client/src/components/TagTraining.component.tsx @@ -71,7 +71,7 @@ export const TagTrainingComponent: React.FC = (props) useEffect(() => { const entries = Array.from(new Set(trainingSet)); props.setTrainingSet(entries); - }, [trainingSet]) + }, [trainingSet]); // TODO: In the future, the datasets retrieved should only be datasets // accessible by the current project diff --git a/packages/client/src/pages/studies/NewStudy.tsx b/packages/client/src/pages/studies/NewStudy.tsx index e1f29dd2..24023038 100644 --- a/packages/client/src/pages/studies/NewStudy.tsx +++ b/packages/client/src/pages/studies/NewStudy.tsx @@ -9,7 +9,14 @@ import { CreateStudyDocument, CreateStudyMutation, CreateStudyMutationVariables import { useProject } from '../../context/Project.context'; import { useStudy } from '../../context/Study.context'; import { useApolloClient } from '@apollo/client'; -import { CreateTagsDocument, CreateTrainingSetDocument, CreateTagsMutationVariables, CreateTagsMutation, CreateTrainingSetMutation, CreateTrainingSetMutationVariables } from '../../graphql/tag/tag'; +import { + CreateTagsDocument, + CreateTrainingSetDocument, + CreateTagsMutationVariables, + CreateTagsMutation, + CreateTrainingSetMutation, + CreateTrainingSetMutationVariables +} from '../../graphql/tag/tag'; import { useTranslation } from 'react-i18next'; import { TagFieldFragmentSchema, TagField } from '../../components/tagbuilder/TagProvider'; diff --git a/packages/server/src/tag/resolvers/training-set.resolver.ts b/packages/server/src/tag/resolvers/training-set.resolver.ts index 66a59d1c..a8cee47b 100644 --- a/packages/server/src/tag/resolvers/training-set.resolver.ts +++ b/packages/server/src/tag/resolvers/training-set.resolver.ts @@ -15,7 +15,6 @@ import { JwtAuthGuard } from '../../jwt/jwt.guard'; @UseGuards(JwtAuthGuard) @Resolver() export class TrainingSetResolver { - constructor( private readonly trainingService: TrainingSetService, @Inject(CASBIN_PROVIDER) private readonly enforcer: casbin.Enforcer diff --git a/packages/server/src/tag/services/training-set.service.ts b/packages/server/src/tag/services/training-set.service.ts index 9f35932f..774a0531 100644 --- a/packages/server/src/tag/services/training-set.service.ts +++ b/packages/server/src/tag/services/training-set.service.ts @@ -13,6 +13,6 @@ export class TrainingSetService { return this.trainingSetModel.create({ study: study._id, entries: entries.map((entry) => entry._id) - }) + }); } }