import {z} from 'zod';
import {type ConjunctionInputNode} from '../../../../../types/metadata/inputs';

/** Fields every leaf filter carries regardless of operator: the metadata key and the raw value text. */
const leafFilterFields = {
  key: z.string(),
  value: z.string(),
};

/** Value types selectable for the equality operators (`$eq` / `$ne`). */
const equalityValueType = z.enum(['auto', 'number', 'string', 'boolean']);
/** Range comparisons (`$gt`, `$gte`, `$lt`, `$lte`) only accept the `'number'` value type. */
const rangeValueType = z.literal('number');
/** Set-membership operators (`$in`, `$nin`) only accept the `'array'` value type. */
const membershipValueType = z.literal('array');

/**
 * A single filter condition, discriminated on `op`. Each operator pins which `valueType`s it
 * accepts; `key` and `value` are plain strings at this stage.
 */
const LeafInputNode = z.discriminatedUnion('op', [
  z.object({op: z.literal('$eq'), valueType: equalityValueType, ...leafFilterFields}),
  z.object({op: z.literal('$ne'), valueType: equalityValueType, ...leafFilterFields}),
  z.object({op: z.literal('$gt'), valueType: rangeValueType, ...leafFilterFields}),
  z.object({op: z.literal('$gte'), valueType: rangeValueType, ...leafFilterFields}),
  z.object({op: z.literal('$lt'), valueType: rangeValueType, ...leafFilterFields}),
  z.object({op: z.literal('$lte'), valueType: rangeValueType, ...leafFilterFields}),
  z.object({op: z.literal('$in'), valueType: membershipValueType, ...leafFilterFields}),
  z.object({op: z.literal('$nin'), valueType: membershipValueType, ...leafFilterFields}),
]);

/** Accepted values for `queryMode` in `createSchema`: query by vector text or by an ID. */
const queryModeOptions = ['vector', 'id'] as const;

/**
 * A group of filter conditions combined with `$and` or `$or`. Each entry in `nodes` is either a
 * leaf condition or another group, so groups can nest recursively.
 *
 * The self-reference is wrapped in `z.lazy` so it is only resolved at parse time, and the
 * explicit `z.ZodType<ConjunctionInputNode>` annotation supplies the type that TypeScript
 * cannot infer for a recursive schema.
 */
const conjunctionInputNode: z.ZodType<ConjunctionInputNode> = z.object({
  operation: z.enum(['$and', '$or']),
  nodes: z.lazy(() => {
    const filterNode = z.union([LeafInputNode, conjunctionInputNode]);
    return z.array(filterNode);
  }),
});

/**
 * Builds a schema for vector input given either as comma-separated text or as a number array,
 * and parses it into a `number[]`.
 *
 * @param dimensions Exact number of entries required; any negative value (the default) skips
 *   the length check.
 * @param forceInt Parse each entry with `parseInt(entry, 10)` instead of `parseFloat(entry)`.
 *
 * Entries are parsed with `parseFloat`/`parseInt`, so only entries that parse to `NaN` are
 * rejected; trailing non-numeric characters after a number are not.
 * NOTE(review): `.optional()` lets `undefined` through, but it is then split into `['']`, which
 * fails either the dimensions check or the NaN check — so a missing value is reported with one
 * of those messages rather than 'Required'. Confirm this is intended.
 */
const createVectorInputType = (dimensions = -1, forceInt = false) => {
  const parseFeature = (feature: string): number =>
    forceInt ? parseInt(feature, 10) : parseFloat(feature);

  return z
    .union([z.string().min(1, {message: 'Required'}), z.array(z.number())])
    .optional()
    .transform((rawInput) => {
      // Arrays are normalised to the same comma-separated text form as typed input.
      const text = Array.isArray(rawInput) ? rawInput.join(', ') : rawInput;
      return (text || '').split(',');
    })
    .refine((parts) => dimensions < 0 || parts.length === dimensions, {
      message: `Wrong dimensions, expected ${dimensions}`,
    })
    .transform((parts) => parts.map((part) => parseFeature(part)))
    .refine((vector) => vector.every((feature) => !Number.isNaN(feature)), {
      message: 'Not all dimensions are numbers',
    });
};

/** Every value type a metadata entry (or filter value) can be tagged with. */
const valueTypeOptions = ['auto', 'string', 'number', 'boolean', 'array'] as const;
const valueTypes = z.enum(valueTypeOptions);
export type ValueTypes = (typeof valueTypeOptions)[number];

/** One metadata entry: its key, its value kept as a string, and the declared `valueType`. */
const metadataInput = z.object({key: z.string(), value: z.string(), valueType: valueTypes});

/** Parsed shape of a single metadata entry. */
export type MetadataInput = z.output<typeof metadataInput>;

/**
 * Builds the validation schema for upserting a single record.
 *
 * @param dimensions Required length of the dense `vector`.
 *
 * `sparse` is optional; when present, `indices` (parsed as integers) and `values` must have the
 * same number of entries. Empty or missing sub-fields are rejected by their own field schemas
 * (`createVectorInputType`), so the object-level check only compares lengths.
 */
export const createUpsertSchema = (dimensions: number) =>
  z.object({
    vector: createVectorInputType(dimensions),
    id: z.string().min(1, {message: 'ID must not be empty'}),
    namespace: z.string().nullable().optional(),
    metadata: z.array(metadataInput),
    sparse: z
      .object({
        indices: createVectorInputType(/* dimensions= */ -1, /* forceInt= */ true),
        values: createVectorInputType(),
      })
      // No null checks needed: `.optional()` below short-circuits an absent `sparse` before this
      // refine runs, and createVectorInputType always parses to an array.
      .refine((data) => data.indices.length === data.values.length, {
        message: '* Please Fix: Number of indices does not match number of sparse values',
      })
      .optional(),
  });

export type UpsertSchemaType = z.infer<ReturnType<typeof createUpsertSchema>>;

/**
 * Validation schema for listing vectors: optional `namespace`, an ID `prefix`, and a `limit`
 * between 1 and 1000 inclusive.
 */
export const listSchema = z.object({
  namespace: z.string().optional(),
  prefix: z.string(),
  limit: z
    .number()
    // The message previously said "ID", copied from another field; this bound is on the limit.
    .min(1, {message: 'Limit must be greater than 0'})
    .max(1000, {message: 'Max limit is 1000'}),
});
export type ListVectorsSchemaType = z.infer<typeof listSchema>;

/**
 * Builds the validation schema for the query form.
 *
 * @param dimensions Required length of the query vector parsed from `queryVectorText`.
 *
 * In `'id'` mode a non-empty `queryId` is required. `queryVectorText` is validated by its own
 * field schema regardless of `queryMode`.
 * NOTE(review): that means an ID-mode query still needs valid vector text to pass — confirm the
 * form always supplies one (e.g. a prefilled vector).
 */
export const createSchema = (dimensions: number) =>
  z
    .object({
      namespace: z.string().nullable().optional(),
      topK: z.coerce.number().positive().int().lte(1000),
      queryMode: z.enum(queryModeOptions),
      queryVectorText: createVectorInputType(dimensions),
      queryId: z.string().optional(),
      filters: z.array(conjunctionInputNode),
    })
    // Only ID mode needs a cross-field check. A vector-mode "is the vector present?" check would
    // be unreachable: createVectorInputType always parses to an array, and empty or missing
    // vector text is already rejected at the field level.
    .refine((data) => data.queryMode !== 'id' || Boolean(data.queryId), {
      message: 'ID is required',
      path: ['queryId'],
    });

export type SchemaType = z.infer<ReturnType<typeof createSchema>>;
