module Linear_function = struct

(********************)
(* Main definitions *)
(********************)

(* A linear function, represented by a list of pairs.                     *)
(* lin_f terms folds ITLIST over the list starting from &0, adding the    *)
(* product FST tm * SND tm of every pair tm, so the value is the sum of   *)
(* those products (and &0 for the empty list).                            *)
let lin_f = new_definition `lin_f terms =
  ITLIST (\tm x. (FST tm) * (SND tm) + x) terms (&0)`;;
(* Example: let x = `lin_f [&1, x:real; &2, y:real]`;; *)

(**********************)
(* Theorems for lin_f *)
(**********************)

(* Basic properties of lin_f *)
(* Unfolding rule for a non-empty list: the head pair h contributes *)
(* FST h * SND h, which is added to lin_f of the tail.              *)
let LIN_F_CONS =
  prove
    (`!h t. lin_f (CONS h t) = (FST h * SND h) + lin_f t`,
     REWRITE_TAC [lin_f; ITLIST]);;
(* Unfolding rule for the empty list: lin_f [] is &0. *)
let LIN_F_EMPTY =
  prove
    (`lin_f [] = &0`,
     REWRITE_TAC [lin_f; ITLIST]);;
(* lin_f turns list concatenation into addition:       *)
(*   lin_f (APPEND t1 t2) = lin_f t1 + lin_f t2.       *)
(* The proof goes by list induction via LIST_INDUCT_TAC. *)
let LIN_F_APPEND =
  prove
    (`!t1 t2. lin_f (APPEND t1 t2) = lin_f t1 + lin_f t2`,
     LIST_INDUCT_TAC THEN REWRITE_TAC [APPEND] THENL
     [
       (* Empty-list case: finished by rewriting alone. *)
       REWRITE_TAC [LIN_F_EMPTY; REAL_ADD_LID];
       (* CONS case: handled by the tactics after the split. *)
       ALL_TAC
     ] THEN
     GEN_TAC THEN
     ASM_REWRITE_TAC [LIN_F_CONS; REAL_ADD_ASSOC]);;
(* Sum of two lin_f *)
(* Adding two lin_f values whose head pairs share the second component v: *)
(* given c1 + c2 = c, the two heads merge into the single pair (c, v) and *)
(* the tails are added separately.                                        *)
let LIN_F_SUM_HD =
  prove
    (`!v c1 t1 c2 t2 c. c1 + c2 = c ==> lin_f (CONS (c1,v) t1) + lin_f (CONS (c2,v) t2) = lin_f [c,v] + (lin_f t1 + lin_f t2)`,
     REPEAT GEN_TAC THEN
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     (* Use the hypothesis right-to-left, replacing c by c1 + c2. *)
     DISCH_THEN (fun sum_eq -> REWRITE_TAC [SYM sum_eq]) THEN
     REAL_ARITH_TAC);;
(* Variant of LIN_F_SUM_HD for cancelling heads: when c1 + c2 = &0 the *)
(* two head pairs on v disappear and only the tails remain.            *)
let LIN_F_SUM_HD_ZERO =
  prove
    (`!v c1 t1 c2 t2. c1 + c2 = &0 ==> lin_f (CONS (c1,v) t1) + lin_f (CONS (c2,v) t2) = lin_f t1 + lin_f t2`,
     REPEAT GEN_TAC THEN
     (* Instantiate LIN_F_SUM_HD with the hypothesis c1 + c2 = &0. *)
     DISCH_THEN (MP_TAC o MATCH_MP LIN_F_SUM_HD) THEN
     DISCH_THEN
       (fun merged -> REWRITE_TAC [merged; LIN_F_CONS; LIN_F_EMPTY]) THEN
     REAL_ARITH_TAC);;
(* lin_f [] is a left identity for addition of lin_f values. *)
let LIN_F_SUM_EMPTY1 =
  prove
    (`!x. lin_f [] + lin_f x = lin_f x`,
     REWRITE_TAC [LIN_F_EMPTY; REAL_ADD_LID]);;
(* lin_f [] is a right identity for addition of lin_f values. *)
let LIN_F_SUM_EMPTY2 =
  prove
    (`!x. lin_f x + lin_f [] = lin_f x`,
     REWRITE_TAC [LIN_F_EMPTY; REAL_ADD_RID]);;
(* Pull the head pair of the LEFT summand out in front of the sum: *)
(*   lin_f (CONS h t) + lin_f x = lin_f [h] + (lin_f t + lin_f x). *)
let LIN_F_SUM_MOVE1 =
  prove
    (`!h t x. lin_f (CONS h t) + lin_f x = lin_f [h] + (lin_f t + lin_f x)`,
     REPEAT GEN_TAC THEN
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     REAL_ARITH_TAC);;
(* Pull the head pair of the RIGHT summand out in front of the sum: *)
(*   lin_f x + lin_f (CONS h t) = lin_f [h] + (lin_f x + lin_f t).  *)
let LIN_F_SUM_MOVE2 =
  prove
    (`!h t x. lin_f x + lin_f (CONS h t) = lin_f [h] + (lin_f x + lin_f t)`,
     REPEAT GEN_TAC THEN
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     REAL_ARITH_TAC);;
(* Adding a singleton lin_f [h] in front of lin_f t is the same as *)
(* consing h onto the list t.                                      *)
let LIN_F_ADD_SINGLE =
  prove
    (`!h t. lin_f [h] + lin_f t = lin_f (CONS h t)`,
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     REAL_ARITH_TAC);;
(* Cancellation: the product c * v added to a lin_f whose head pair is *)
(* (--c, v) leaves just lin_f of the tail.                             *)
let LIN_F_ADD_SING_RCANCEL =
  prove
    (`!c v t. c * v + lin_f (CONS (--c, v) t) = lin_f t`,
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     REAL_ARITH_TAC);;
(* Cancellation: the product --c * v added to a lin_f whose head pair *)
(* is (c, v) leaves just lin_f of the tail.                           *)
let LIN_F_ADD_SING_LCANCEL =
  prove
    (`!c v t. --c * v + lin_f (CONS (c, v) t) = lin_f t`,
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     REAL_ARITH_TAC);;
(* Multiplication of lin_f *)
(* Multiplying lin_f t by v equals lin_f of the list obtained by       *)
(* multiplying the first component of every pair by v (via MAP).       *)
(* The proof unfolds lin_f and goes by list induction on t.            *)
let LIN_F_MUL =
  prove
    (`!v t. v * lin_f t = lin_f (MAP (\tm. v * FST tm, SND tm) t)`,
     REWRITE_TAC [lin_f] THEN
     GEN_TAC THEN
     LIST_INDUCT_TAC THENL
     [
       (* Empty-list case: finished by rewriting alone. *)
       REWRITE_TAC [ITLIST; MAP; REAL_MUL_RZERO];
       (* CONS case: handled by the tactics after the split. *)
       ALL_TAC
     ] THEN
     REWRITE_TAC [ITLIST; MAP] THEN
     ASM_REWRITE_TAC [REAL_ADD_LDISTRIB; REAL_MUL_ASSOC]);;
(* One step of multiplying a lin_f by x: given x * c = c1, the head   *)
(* pair (c, v) becomes the singleton lin_f [c1, v] and the factor x   *)
(* remains on lin_f of the tail.                                      *)
let LIN_F_MUL_HD =
  prove
    (`!x c v t c1. x * c = c1 ==> x * lin_f (CONS (c,v) t) = lin_f [c1, v] + x * lin_f t`,
     REWRITE_TAC [LIN_F_CONS; LIN_F_EMPTY] THEN
     REPEAT STRIP_TAC THEN
     (* The auxiliary identity exposes x * c so the assumption can rewrite it. *)
     ASM_REWRITE_TAC
       [REAL_ARITH `x * (c * v + a) = (x * c) * v + x * a:real`] THEN
     REAL_ARITH_TAC);;
(* Multiplying lin_f [] by any x gives lin_f [] again. *)
let LIN_F_MUL_EMPTY =
  prove
    (`!x. x * lin_f [] = lin_f []`,
     REWRITE_TAC [LIN_F_EMPTY] THEN
     REAL_ARITH_TAC);;
(* Theorems for converting sums into lin_f *)
(* Any x can be written as x + lin_f [], i.e. with an empty lin_f *)
(* appended on the right.                                         *)
let TO_LIN_F0 =
  prove
    (`!x. x = x + lin_f []`,
     REWRITE_TAC [LIN_F_EMPTY; REAL_ADD_RID]);;
(* A single product a * x equals lin_f of the one-element list [a, x]. *)
let TO_LIN_F1 =
  prove
    (`!a x. a * x = lin_f [a, x]`,
     REWRITE_TAC [lin_f; ITLIST; REAL_ADD_RID]);;
(* Absorbing a product c * v into an existing lin_f t, in two forms:   *)
(*   - with a remaining summand x: (c * v + x) + lin_f t becomes       *)
(*     x + lin_f (CONS (c, v) t);                                      *)
(*   - without one: c * v + lin_f t becomes lin_f (CONS (c, v) t).     *)
let TO_LIN_F =
  prove
    (`!c v x t. (c * v + x) + lin_f t = x + lin_f (CONS (c, v) t) /\ c * v + lin_f t = lin_f (CONS (c, v) t)`,
     REWRITE_TAC [LIN_F_CONS] THEN
     REAL_ARITH_TAC);;
end;;