import { Button, Grid, Stack } from "@mui/material";
import { FormikTouched, FormikValues, useFormik } from "formik";
import _ from "lodash";
import React, { useCallback, useEffect, useMemo, useState } from "react";
import { useHistory, useLocation } from "react-router-dom";
import { DataFieldDef } from "utils/data-field/data-field-def";
import { ObjectFields, pickData, validateWithObjectDef } from "utils/data-field/object-field-def";
import { keysOf, mapRecord } from "utils/type-utils";
import { MapKeys, NestedPartial } from "utils/utility-types";
import { InputUiBuilder } from "../../helpers/input-ui-builder";

/**
 * Configuration for {@link Wizard}: one step config per key of `WizardData`.
 *
 * `ContextData` holds the per-step context; the `MapKeys` constraint presumably ties its keys to the keys
 * of `WizardData` (TODO confirm against `utils/utility-types`).
 */
export interface WizardConfig<WizardData extends object, ContextData extends MapKeys<WizardData, ContextData>> {
  // The wizard presents steps in the order returned by `keysOf(steps)`.
  steps: { [StepKey in keyof WizardData]: WizardStepConfig<WizardData, ContextData, StepKey> };
}

/**
 * Configuration of a single wizard step; can be built with `makeWizardStep`.
 */
export type WizardStepConfig<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
> = {
  // Definition of the step's data; the wizard calls `isValid` on it to find the first incomplete step.
  stepDataDef: DataFieldDef<WizardData[StepKey]>;
  // Component that renders the step.
  Component: WizardStepComponent<WizardData, ContextData, StepKey>;
  contextConfig: {
    // Whether the step's context stays in the wizard state after navigating away from the step
    // (`makeWizardStep` defaults this to `true`).
    cacheStepContext?: boolean;
    // Loads the step's context before the step is rendered.
    loadContext: WizardContextLoader<WizardData, ContextData, StepKey>;
  };
};

/**
 * Component that renders one wizard step.
 *
 * Props:
 * - `stepData`: data previously entered for this step (possibly partial), or `undefined`.
 * - `stepContext`: the context loaded for this step.
 * - `wizardState`: a copy of the whole wizard state.
 * - `registerDataGetter`: the step registers a getter here; the wizard calls it when either navigation
 *   button (back/cancel or next/submit) is pressed, to read the step's current data.
 */
export type WizardStepComponent<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
> = (props: {
  stepData: Partial<WizardData[StepKey]> | undefined;
  stepContext: ContextData[StepKey];
  wizardState: WizardState<WizardData, ContextData>;
  registerDataGetter(getter: DataGetter<WizardData[StepKey], ContextData[StepKey]>): void;
}) => JSX.Element;

/**
 * Reads a step's current data for the wizard, given the step's context.
 *
 * Returns `valid: true` with complete data, or `valid: false` with partial data. Both forms also carry a
 * step context, which the wizard stores for the step when navigating Back/Next.
 */
type DataGetter<StepData, StepContext> = (
  stepContext: StepContext,
) => ({ valid: true; data: StepData } | { valid: false; data: Partial<StepData> }) & { stepContext: StepContext };

/**
 * Loads the context of one step before the step is rendered.
 *
 * @param stepData the step's current data and context; either may be missing.
 * @param wizardState a copy of the whole wizard state.
 * @returns the step's data and context; both replace the step's entries in the wizard state.
 */
export type WizardContextLoader<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
> = (
  stepData: Partial<WizardStepState<WizardData, ContextData, StepKey>>,
  wizardState: WizardState<WizardData, ContextData>,
) => Promise<WizardStepState<WizardData, ContextData, StepKey>>;

/**
 * Context settings accepted by `makeWizardStep`.
 */
type WizardContextConfig<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
> = {
  // Keep the step's context after leaving the step; `makeWizardStep` defaults this to `true`.
  cacheStepContext?: boolean;
  // Loads the step's context before the step is rendered.
  loadContext: WizardContextLoader<WizardData, ContextData, StepKey>;
};

/**
 * Assembles a {@link WizardStepConfig} from its parts.
 *
 * @param stepDataDef definition of the step's data.
 * @param Component component rendering the step.
 * @param contextConfig context settings; `cacheStepContext` falls back to `true` when null/undefined.
 * @returns a new step config (the passed `contextConfig` object is not reused).
 */
export function makeWizardStep<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
>(
  stepDataDef: DataFieldDef<WizardData[StepKey]>,
  Component: WizardStepComponent<WizardData, ContextData, StepKey>,
  contextConfig: WizardContextConfig<WizardData, ContextData, StepKey>,
): WizardStepConfig<WizardData, ContextData, StepKey> {
  // caching is opt-out: an absent flag means "cache"
  const cacheStepContext = contextConfig.cacheStepContext ?? true;
  const { loadContext } = contextConfig;
  const normalizedContextConfig = { cacheStepContext, loadContext };
  return { stepDataDef, Component, contextConfig: normalizedContextConfig };
}

/**
 * Wizard-wide state: the data and loaded context of the steps so far, plus the current step.
 * `Wizard` also pushes this object as the router location state.
 */
type WizardState<WizardData extends object, ContextData extends MapKeys<WizardData, ContextData>> = {
  wizardData: NestedPartial<WizardData>;
  contextData: Partial<ContextData>;
  stepKey: keyof WizardData;
};

/**
 * Data and context of a single step, as resolved by a `WizardContextLoader`.
 */
export type WizardStepState<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
> = {
  stepData: Partial<WizardData[StepKey]>;
  stepContext: ContextData[StepKey];
};

/**
 * Multi-step form wizard.
 *
 * Steps are shown one at a time, in the order of `keysOf(config.steps)`. Before a step is rendered, its
 * context is loaded with the step's `contextConfig.loadContext`. Every step change pushes a router history
 * entry (`<pathname>#<stepKey>`) whose state is the whole wizard state. On mount, data and context are
 * restored from the current location state when present.
 *
 * @param inputWizardData initial data, used when the location state carries no wizard data.
 * @param config per-step configuration.
 * @param onStepChange called with the current step key after it is committed: on mount, on every step
 *   change, and whenever a different `onStepChange` function is passed.
 * @param onCancel called when the back button of the first step ("Cancel") is pressed.
 * @param onSubmit called with the merged data and context when the last step is submitted and the whole
 *   data passes `validateWithObjectDef`.
 */
export function Wizard<WizardData extends object, ContextData extends MapKeys<WizardData, ContextData>>({
  inputWizardData,
  config,
  onStepChange,
  onCancel,
  onSubmit,
}: {
  inputWizardData?: WizardData;
  config: WizardConfig<WizardData, ContextData>;
  onStepChange?: (stepKey: keyof WizardData) => void;
  onCancel: () => void;
  onSubmit: (finalData: WizardData, contextData: ContextData) => void;
}) {
  const urlLocation = useLocation();
  const urlHistory = useHistory();

  // The wizard state is persisted as the router location state.
  const readLocationCache = useCallback(() => {
    return (urlLocation.state ?? {}) as Partial<WizardState<WizardData, ContextData>>;
  }, [urlLocation.state]);
  // Pushes a history entry for `key`; pushing without state clears the cached wizard state.
  const writeLocationCache: (key: keyof WizardData, state?: WizardState<WizardData, ContextData>) => void = useCallback(
    (key: keyof WizardData, data?: WizardState<WizardData, ContextData>) => {
      urlHistory.push(urlLocation.pathname + "#" + String(key), data);
    },
    [urlHistory, urlLocation.pathname],
  );

  // `keysOf` builds a new array on every call; memoize it so it is only rebuilt when the config changes.
  const allKeys = useMemo(() => keysOf(config.steps), [config.steps]);

  // The initial state is only needed on mount, so compute it lazily rather than on every render.
  const [wizardState, setWizardState] = useState<WizardState<WizardData, ContextData>>(() => {
    const urlStepKey = urlLocation.hash.slice(1);
    const lastIdx = Math.max(0, allKeys.findIndex(k => String(k) === urlStepKey));
    const cached = readLocationCache();
    const wizardData: NestedPartial<WizardData> = cached?.wizardData ?? inputWizardData ?? {};
    const contextData: Partial<ContextData> = cached?.contextData ?? {};

    // Fast forward to the first incomplete step. Only when every step is complete is the step named by the
    // URL hash used (or the first step, if the hash names no step).
    const firstIncompleteKey = allKeys.find(key => !config.steps[key].stepDataDef.isValid(wizardData[key] ?? {}));
    const stepKey = firstIncompleteKey ?? allKeys[lastIdx];

    return { wizardData, contextData, stepKey };
  });

  const stepIdx = allKeys.indexOf(wizardState.stepKey);
  const prevKey = allKeys[stepIdx - 1];
  const nextKey = allKeys[stepIdx + 1];

  // Notify the parent after commit instead of during render: calling a parent callback while rendering can
  // update the parent's state mid-render, which React warns about.
  useEffect(() => {
    onStepChange?.(wizardState.stepKey);
  }, [onStepChange, wizardState.stepKey]);

  // NOTE(review): assigning `window.location.hash` goes around react-router's history object; confirm intended.
  useEffect(() => {
    window.location.hash = String(wizardState.stepKey);
  }, [wizardState.stepKey]);

  // In-flight context loads per step, so repeated requests for the same step share one promise. Held in
  // lazily-initialized state because React may discard `useMemo` values (it is only a performance hint).
  const [contextLoadingPromises] = useState<{
    [K in keyof WizardData]?: Promise<WizardStepState<WizardData, ContextData, K>>;
  }>(() => ({}));

  // Loads (or joins the in-flight load of) a step's context, then stores the resolved data and context.
  const loadContextData = useCallback(
    async function <StepKey extends keyof WizardData>(stepKey: StepKey) {
      const { wizardData, contextData } = _.cloneDeep(wizardState);
      let contextLoadingPromise = contextLoadingPromises[stepKey];
      if (!contextLoadingPromise) {
        contextLoadingPromise = contextLoadingPromises[stepKey] = config.steps[stepKey].contextConfig.loadContext(
          { stepData: wizardData[stepKey], stepContext: contextData[stepKey] },
          { wizardData, contextData, stepKey },
        );
      }
      let loaded: WizardStepState<WizardData, ContextData, StepKey>;
      try {
        loaded = await contextLoadingPromise;
      } finally {
        // Clear the entry even when the load fails, so a later attempt starts a fresh load instead of
        // re-awaiting the same rejected promise forever.
        contextLoadingPromises[stepKey] = undefined;
      }
      wizardData[stepKey] = loaded.stepData;
      contextData[stepKey] = loaded.stepContext;
      setWizardState({ wizardData, contextData, stepKey: wizardState.stepKey });
    },
    [contextLoadingPromises, config.steps, wizardState],
  );

  // Stores the current step's data/context, moves to `newStepKey` and records the new state in history.
  const pushState = useCallback(
    (
      newStepKey: keyof WizardData,
      stepData: Partial<WizardData[keyof WizardData]>,
      stepContext: ContextData[keyof WizardData],
    ) => {
      const wizardData = { ...wizardState.wizardData, [wizardState.stepKey]: stepData };
      const contextData = { ...wizardState.contextData, [wizardState.stepKey]: stepContext };

      // Drop the current step's context unless it is cached. An unset flag means "cache", matching the
      // default applied by `makeWizardStep`.
      if (!(config.steps[wizardState.stepKey].contextConfig.cacheStepContext ?? true)) {
        delete contextData[wizardState.stepKey];
      }

      const newWizardState = { wizardData, contextData, stepKey: newStepKey };
      writeLocationCache(newStepKey, newWizardState);
      setWizardState(newWizardState);
      // `onStepChange` is fired by the effect above once the new step is committed.
    },
    [wizardState, config.steps, writeLocationCache],
  );

  const doCancel = useCallback(() => {
    writeLocationCache(wizardState.stepKey, undefined);
    onCancel();
  }, [onCancel, writeLocationCache, wizardState.stepKey]);

  // Object definition of the whole wizard data, assembled from the per-step definitions.
  const dataDef = useMemo(() => {
    const fieldDefs = mapRecord(config.steps, (_, s) => s.stepDataDef) as unknown as ObjectFields<WizardData>;
    return Object.freeze(fieldDefs);
  }, [config.steps]);

  // Submits only when the complete data validates; clears the cached state first.
  const doSubmit = useCallback(
    (newData: NestedPartial<WizardData>, contextData: Partial<ContextData>) => {
      if (validateWithObjectDef(newData as Partial<WizardData>, dataDef as Readonly<ObjectFields<WizardData>>)) {
        writeLocationCache(wizardState.stepKey, undefined);
        onSubmit(newData as WizardData, contextData as ContextData);
      }
    },
    [dataDef, onSubmit, writeLocationCache, wizardState.stepKey],
  );

  // Children receive deep copies so they cannot mutate the wizard state in place.
  const stateClone = _.cloneDeep(wizardState);
  const stepContext = _.cloneDeep(stateClone.contextData[stateClone.stepKey]);
  return (
    <Stack spacing={2} justifyContent="space-between">
      {stepContext ? (
        <WizardStep<WizardData, ContextData, keyof WizardData>
          // remount the step whenever the previous, current or next key changes
          key={[prevKey, wizardState.stepKey, nextKey].map(String).join("-")}
          stepData={stateClone.wizardData[stateClone.stepKey]}
          stepContext={stepContext}
          wizardState={stateClone}
          StepComponent={config.steps[wizardState.stepKey].Component}
          backAction={
            prevKey != undefined
              ? { label: "Back", onClick: (stepData, stepContext) => pushState(prevKey, stepData, stepContext) }
              : { label: "Cancel", onClick: doCancel }
          }
          nextAction={
            nextKey != undefined
              ? {
                  label: "Next",
                  onClick: (stepData, stepContext) => pushState(nextKey, stepData, stepContext),
                }
              : {
                  label: "Submit",
                  onClick: (stepData, stepContext) => {
                    // Merge the step context returned by the data getter too, as `pushState` does.
                    const mergedData = { ...wizardState.wizardData, [wizardState.stepKey]: stepData };
                    const mergedContext = { ...stateClone.contextData, [stateClone.stepKey]: stepContext };
                    doSubmit(mergedData, mergedContext);
                  },
                }
          }
        />
      ) : (
        // No context for the current step yet: load it; the resulting state update renders the step.
        <ContextLoader loadContextData={() => loadContextData(wizardState.stepKey)} />
      )}
    </Stack>
  );
}

/**
 * Placeholder rendered while a step's context is missing; triggers the context load after commit.
 *
 * The load runs in an effect rather than during render (rendering must stay free of side effects), and
 * its rejection is handled so a failed load does not surface as an unhandled promise rejection.
 *
 * @param props.loadContextData starts (or joins) loading the current step's context.
 */
function ContextLoader(props: { loadContextData: () => Promise<void> }) {
  const { loadContextData } = props;
  useEffect(() => {
    loadContextData().catch(error => {
      console.error("Failed to load wizard step context", error);
    });
  }, [loadContextData]);
  return <div />;
}

/**
 * Renders one wizard step together with its back and next buttons.
 *
 * The step component registers a data getter through `registerDataGetter`. On a button click the getter
 * is called with the current step context:
 * - back: `backAction.onClick` receives the (possibly partial) data regardless of validity;
 * - next: `nextAction.onClick` is only called when the getter reports valid data.
 * Clicks are ignored while no getter has been registered.
 */
function WizardStep<
  WizardData extends object,
  ContextData extends MapKeys<WizardData, ContextData>,
  StepKey extends keyof WizardData,
>({
  stepData,
  stepContext,
  wizardState,
  StepComponent,
  nextAction,
  backAction,
}: {
  stepData: Partial<WizardData[StepKey]> | undefined;
  stepContext: ContextData[StepKey];
  wizardState: WizardState<WizardData, ContextData>;
  StepComponent: WizardStepComponent<WizardData, ContextData, StepKey>;
  backAction: {
    label: string;
    // can go back with partial data
    onClick: (data: Partial<WizardData[StepKey]>, context: ContextData[StepKey]) => void;
  };
  nextAction: {
    label: string;
    // need full data to next
    onClick: (data: WizardData[StepKey], context: ContextData[StepKey]) => void;
  };
}) {
  // Mutable holder for the getter registered by the step component; it is only read from click handlers,
  // i.e. while the step component is mounted. Lazily-initialized state guarantees the same object for the
  // component's lifetime (`useMemo` is only a performance hint and may be recomputed).
  const [dataGetterHolder] = useState<{ dataGetter?: DataGetter<WizardData[StepKey], ContextData[StepKey]> }>(
    () => ({}),
  );
  const registerDataGetter = useCallback(
    (getter: DataGetter<WizardData[StepKey], ContextData[StepKey]>) => {
      dataGetterHolder.dataGetter = getter;
    },
    [dataGetterHolder],
  );

  const handleBack = useCallback(() => {
    const result = dataGetterHolder.dataGetter?.(stepContext);
    if (result) {
      // no need to check validity when going back.
      backAction.onClick(result.data, result.stepContext);
    }
  }, [backAction, dataGetterHolder, stepContext]);

  const handleNext = useCallback(() => {
    // An invalid result keeps the user on this step. The getter registered by `useWizardStepSetup` also
    // touches every field, so the validation errors become visible.
    const result = dataGetterHolder.dataGetter?.(stepContext);
    if (result?.valid) {
      nextAction.onClick(result.data, result.stepContext);
    }
  }, [nextAction, dataGetterHolder, stepContext]);

  return (
    <Grid container gap={7}>
      <Grid item xs={12}>
        <StepComponent
          stepData={stepData}
          stepContext={stepContext}
          wizardState={wizardState}
          registerDataGetter={registerDataGetter}
        />
      </Grid>
      <Grid item direction="row" xs={12}>
        <Stack direction="row" spacing={2} justifyContent={"flex-end"}>
          <Button variant="outlined" onClick={handleBack}>
            {backAction.label}
          </Button>
          <Button variant="contained" onClick={handleNext}>
            {nextAction.label}
          </Button>
        </Stack>
      </Grid>
    </Grid>
  );
}

/**
 * Formik setup for a wizard step whose data is described by an object field definition.
 *
 * Creates a Formik form seeded with the fields of `initialData` picked by the definition, validated
 * against the definition's schema. On every render it registers a data getter that validates the current
 * values and marks every field as touched, so validation errors are displayed.
 *
 * @param initialData initial values of the step.
 * @param objectDef field definition of the step data; must be an `object` definition.
 * @param registerDataGetter callback provided to the step component by the wizard.
 * @returns the Formik form and an `InputUiBuilder` bound to it.
 * @throws Error when `objectDef` is not an `object` field definition.
 */
export function useWizardStepSetup<StepData extends object, StepContext>(
  initialData: StepData,
  objectDef: DataFieldDef<StepData>,
  registerDataGetter: (getter: DataGetter<StepData, StepContext>) => void,
) {
  const { metaData } = objectDef;
  if (metaData.kind !== "object") {
    throw new Error(`Expected 'object' field def, found ${metaData.kind}`);
  }
  const schema = objectDef.getSchema();
  const stepFieldDefs = metaData.fieldDefs as Readonly<ObjectFields<StepData>>;
  const form = useFormik<StepData>({
    initialValues: pickData(stepFieldDefs, initialData),
    validationSchema: schema,
    onSubmit: () => undefined,
  });

  // Re-registered on each render so the wizard always reads this render's form values.
  registerDataGetter(stepContext => {
    const currentValues = form.values;
    const isValid = schema.isValidSync(currentValues);
    // Touch every field so that all validation errors become visible in the UI.
    form.setTouched(mapFormTouched(objectDef as DataFieldDef<unknown>, currentValues));

    if (!isValid) {
      return { valid: isValid, data: currentValues as Partial<StepData>, stepContext };
    }
    return { valid: isValid, data: currentValues as StepData, stepContext };
  });

  const uiBuilder = useMemo(() => new InputUiBuilder<StepData>(stepFieldDefs, form), [stepFieldDefs, form]);
  return { form, uiBuilder };
}

/**
 * Builds a Formik `touched` value that marks every field described by `def` as touched, following the
 * shape of `values`.
 *
 * - `object`: recurses into every declared field.
 * - `union`: always marks the `kind` selector as touched (also when no kind is chosen yet, so a missing
 *   kind shows its error); recurses into `value` when the selected kind has a value definition. Kinds
 *   without a mapping entry or mapped to "no-value" only touch `kind`.
 * - `array`: one entry per element of `values` (empty when `values` is not an array).
 * - anything else (leaf fields): `true`.
 */
function mapFormTouched(def: DataFieldDef<unknown>, values: FormikValues | undefined): FormikTouched<unknown> {
  switch (def.metaData.kind) {
    case "object": {
      const fieldDefs = def.metaData.fieldDefs as Readonly<ObjectFields<object>>;
      return mapRecord(fieldDefs, (k, d) => mapFormTouched(d as DataFieldDef<unknown>, values?.[k]));
    }
    case "union": {
      const kind = values?.["kind"];
      if (!kind) {
        return { kind: true };
      }
      const valueDef = def.metaData.defMapping[kind];
      // Guard against kinds missing from the mapping (e.g. stale values); recursing would crash on `undefined`.
      if (valueDef == undefined || valueDef == "no-value") {
        return { kind: true };
      }
      return { kind: true, value: mapFormTouched(valueDef, values?.["value"]) };
    }
    case "array": {
      const { elementDef } = def.metaData;
      if (Array.isArray(values)) {
        return values.map(v => mapFormTouched(elementDef, v));
      } else {
        return [];
      }
    }
    default: {
      return true;
    }
  }
}
