diff --git a/stmhal/flash.c b/stmhal/flash.c
index 6ed3b04324fb3f005aabfe44374bd36dfb674be2..ce16fe271b9d1c56da0a60de585a03340f8f35d5 100644
--- a/stmhal/flash.c
+++ b/stmhal/flash.c
@@ -49,7 +49,81 @@ uint32_t flash_get_sector_info(uint32_t addr, uint32_t *start_addr, uint32_t *si
     return 0;
 }
 
+void flash_erase(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32) {
+    // check there is something to write
+    if (num_word32 == 0) {
+        return;
+    }
+
+    // unlock
+    HAL_FLASH_Unlock();
+
+    // Clear pending flags (if any)
+    __HAL_FLASH_CLEAR_FLAG(FLASH_FLAG_EOP | FLASH_FLAG_OPERR | FLASH_FLAG_WRPERR |
+                           FLASH_FLAG_PGAERR | FLASH_FLAG_PGPERR|FLASH_FLAG_PGSERR);
+
+    // erase the sector(s)
+    FLASH_EraseInitTypeDef EraseInitStruct;
+    EraseInitStruct.TypeErase = TYPEERASE_SECTORS;
+    EraseInitStruct.VoltageRange = VOLTAGE_RANGE_3; // voltage range needs to be 2.7V to 3.6V
+    EraseInitStruct.Sector = flash_get_sector_info(flash_dest, NULL, NULL);
+    EraseInitStruct.NbSectors = flash_get_sector_info(flash_dest + 4 * num_word32 - 1, NULL, NULL) - EraseInitStruct.Sector + 1;
+    uint32_t SectorError = 0;
+    if (HAL_FLASHEx_Erase(&EraseInitStruct, &SectorError) != HAL_OK) {
+        // error occurred during sector erase
+        HAL_FLASH_Lock(); // lock the flash
+        return;
+    }
+}
+
+/*
+// erase the sector using an interrupt
+void flash_erase_it(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32) {
+    // check there is something to write
+    if (num_word32 == 0) {
+        return;
+    }
+
+    // unlock
+    HAL_FLASH_Unlock();
+
+    // Clear pending flags (if any)
+    __HAL_FLASH_CLEAR_FLAG(FLASH_FLAG_EOP | FLASH_FLAG_OPERR | FLASH_FLAG_WRPERR |
+                           FLASH_FLAG_PGAERR | FLASH_FLAG_PGPERR|FLASH_FLAG_PGSERR);
+
+    // erase the sector(s)
+    FLASH_EraseInitTypeDef EraseInitStruct;
+    EraseInitStruct.TypeErase = TYPEERASE_SECTORS;
+    EraseInitStruct.VoltageRange = VOLTAGE_RANGE_3; // voltage range needs to be 2.7V to 3.6V
+    EraseInitStruct.Sector = flash_get_sector_info(flash_dest, NULL, NULL);
+    EraseInitStruct.NbSectors = flash_get_sector_info(flash_dest + 4 * num_word32 - 1, NULL, NULL) - EraseInitStruct.Sector + 1;
+    if (HAL_FLASHEx_Erase_IT(&EraseInitStruct) != HAL_OK) {
+        // error occurred during sector erase
+        HAL_FLASH_Lock(); // lock the flash
+        return;
+    }
+}
+*/
+
 void flash_write(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32) {
+    // program the flash word by word
+    for (int i = 0; i < num_word32; i++) {
+        if (HAL_FLASH_Program(TYPEPROGRAM_WORD, flash_dest, *src) != HAL_OK) {
+            // error occurred during flash write
+            HAL_FLASH_Lock(); // lock the flash
+            return;
+        }
+        flash_dest += 4;
+        src += 1;
+    }
+
+    // lock the flash
+    HAL_FLASH_Lock();
+}
+
+/*
+ use erase, then write
+void flash_erase_and_write(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32) {
     // check there is something to write
     if (num_word32 == 0) {
         return;
@@ -71,6 +145,7 @@ void flash_write(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32)
     uint32_t SectorError = 0;
     if (HAL_FLASHEx_Erase(&EraseInitStruct, &SectorError) != HAL_OK) {
         // error occurred during sector erase
+        HAL_FLASH_Lock(); // lock the flash
         return;
     }
 
@@ -78,6 +153,7 @@ void flash_write(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32)
     for (int i = 0; i < num_word32; i++) {
         if (HAL_FLASH_Program(TYPEPROGRAM_WORD, flash_dest, *src) != HAL_OK) {
             // error occurred during flash write
+            HAL_FLASH_Lock(); // lock the flash
             return;
         }
         flash_dest += 4;
@@ -87,3 +163,4 @@ void flash_write(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32)
     // lock the flash
     HAL_FLASH_Lock();
 }
+*/
diff --git a/stmhal/flash.h b/stmhal/flash.h
index 33d31df7a69a7744819a60a22a057b9a9d29aa95..7900d2c1aeb6022f99f41ab0d19a1c83f2339b07 100644
--- a/stmhal/flash.h
+++ b/stmhal/flash.h
@@ -1,2 +1,3 @@
 uint32_t flash_get_sector_info(uint32_t addr, uint32_t *start_addr, uint32_t *size);
+void flash_erase(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32);
 void flash_write(uint32_t flash_dest, const uint32_t *src, uint32_t num_word32);
diff --git a/stmhal/pybstdio.c b/stmhal/pybstdio.c
index d40cebf9470dc766e464851fb96c86734f26ba9e..baabfd120cbcb2cd8d1a0a19502c22ef5a9d8ef3 100644
--- a/stmhal/pybstdio.c
+++ b/stmhal/pybstdio.c
@@ -9,7 +9,6 @@
 #include "obj.h"
 #include "stream.h"
 #include "pybstdio.h"
-#include "storage.h"
 #include "usb.h"
 #include "usart.h"
 
@@ -50,9 +49,6 @@ int stdin_rx_chr(void) {
             return usart_rx_char(pyb_usart_global_debug);
         }
         __WFI();
-        if (storage_needs_flush()) {
-            storage_flush();
-        }
     }
 }
 
diff --git a/stmhal/pyexec.c b/stmhal/pyexec.c
index 298e58a5fd600c4a4a624b27ce13dfd9f3d92bdc..9821e6a92934d19383082d4bd5dcc5249aee6cb0 100644
--- a/stmhal/pyexec.c
+++ b/stmhal/pyexec.c
@@ -21,7 +21,6 @@
 #include "pybstdio.h"
 #include "readline.h"
 #include "pyexec.h"
-#include "storage.h"
 #include "usb.h"
 #include "build/py/py-version.h"
 
diff --git a/stmhal/stm32f4xx_it.c b/stmhal/stm32f4xx_it.c
index aee689d7e2001e6c6cbd2f9952906a50fc8cb9aa..06428a4c4fad37c3e6516c5ea793b29bc173da48 100644
--- a/stmhal/stm32f4xx_it.c
+++ b/stmhal/stm32f4xx_it.c
@@ -49,6 +49,7 @@
 #include "obj.h"
 #include "exti.h"
 #include "timer.h"
+#include "storage.h"
 
 /** @addtogroup STM32F4xx_HAL_Examples
   * @{
@@ -263,6 +264,19 @@ void OTG_XX_WKUP_IRQHandler(void)
 {
 }*/
 
+// Handle a flash (erase/program) interrupt.
+void FLASH_IRQHandler(void) {
+    // This calls the real flash IRQ handler, if needed
+    /*
+    uint32_t flash_cr = FLASH->CR;
+    if ((flash_cr & FLASH_IT_EOP) || (flash_cr & FLASH_IT_ERR)) {
+        HAL_FLASH_IRQHandler();
+    }
+    */
+    // This call the storage IRQ handler, to check if the flash cache needs flushing
+    storage_irq_handler();
+}
+
 /**
   * @brief  These functions handle the EXTI interrupt requests.
   * @param  None
diff --git a/stmhal/storage.c b/stmhal/storage.c
index 98d5ee7e70551f8a8a56346e30c8f3d11c976dd3..0348986f26e86b8432384292ea448c7a28f616c1 100644
--- a/stmhal/storage.c
+++ b/stmhal/storage.c
@@ -16,20 +16,22 @@
 #define FLASH_PART1_NUM_BLOCKS (224) // 16k+16k+16k+64k=112k
 #define FLASH_MEM_START_ADDR (0x08004000) // sector 1, 16k
 
+#define FLASH_FLAG_DIRTY        (1)
+#define FLASH_FLAG_FORCE_WRITE  (2)
+#define FLASH_FLAG_ERASED       (4)
 static bool flash_is_initialised = false;
-static bool flash_cache_dirty;
+static __IO uint8_t flash_flags = 0;
 static uint32_t flash_cache_sector_id;
 static uint32_t flash_cache_sector_start;
 static uint32_t flash_cache_sector_size;
 static uint32_t flash_tick_counter_last_write;
 
 static void flash_cache_flush(void) {
-    if (flash_cache_dirty) {
-        // sync the cache RAM buffer by writing it to the flash page
-        flash_write(flash_cache_sector_start, (const uint32_t*)CACHE_MEM_START_ADDR, flash_cache_sector_size / 4);
-        flash_cache_dirty = false;
-        // indicate a clean cache with LED off
-        led_state(PYB_LED_R1, 0);
+    if (flash_flags & FLASH_FLAG_DIRTY) {
+        flash_flags |= FLASH_FLAG_FORCE_WRITE;
+        while (flash_flags & FLASH_FLAG_DIRTY) {
+           NVIC->STIR = FLASH_IRQn;
+        }
     }
 }
 
@@ -44,9 +46,9 @@ static uint8_t *flash_cache_get_addr_for_write(uint32_t flash_addr) {
         flash_cache_sector_start = flash_sector_start;
         flash_cache_sector_size = flash_sector_size;
     }
-    flash_cache_dirty = true;
-    // indicate a dirty cache with LED on
-    led_state(PYB_LED_R1, 1);
+    flash_flags |= FLASH_FLAG_DIRTY;
+    led_state(PYB_LED_R1, 1); // indicate a dirty cache with LED on
+    flash_tick_counter_last_write = HAL_GetTick();
     return (uint8_t*)CACHE_MEM_START_ADDR + flash_addr - flash_sector_start;
 }
 
@@ -64,11 +66,17 @@ static uint8_t *flash_cache_get_addr_for_read(uint32_t flash_addr) {
 
 void storage_init(void) {
     if (!flash_is_initialised) {
-        flash_cache_dirty = false;
+        flash_flags = 0;
         flash_cache_sector_id = 0;
-        flash_is_initialised = true;
         flash_tick_counter_last_write = 0;
+        flash_is_initialised = true;
     }
+
+    // Enable the flash IRQ, which is used to also call our storage IRQ handler
+    // It needs to go at a higher priority than all those components that rely on
+    // the flash storage (eg higher than USB MSC).
+    HAL_NVIC_SetPriority(FLASH_IRQn, 1, 1);
+    HAL_NVIC_EnableIRQ(FLASH_IRQn);
 }
 
 uint32_t storage_get_block_size(void) {
@@ -79,9 +87,47 @@ uint32_t storage_get_block_count(void) {
     return FLASH_PART1_START_BLOCK + FLASH_PART1_NUM_BLOCKS;
 }
 
-bool storage_needs_flush(void) {
-    // wait 2 seconds after last write to flush
-    return flash_cache_dirty && sys_tick_has_passed(flash_tick_counter_last_write, 2000);
+void storage_irq_handler(void) {
+    if (!(flash_flags & FLASH_FLAG_DIRTY)) {
+        return;
+    }
+
+    // This code uses interrupts to erase the flash
+    /*
+    if (flash_erase_state == 0) {
+        flash_erase_it(flash_cache_sector_start, (const uint32_t*)CACHE_MEM_START_ADDR, flash_cache_sector_size / 4);
+        flash_erase_state = 1;
+        return;
+    }
+
+    if (flash_erase_state == 1) {
+        // wait for erase
+        // TODO add timeout
+        #define flash_erase_done() (__HAL_FLASH_GET_FLAG(FLASH_FLAG_BSY) == RESET)
+        if (!flash_erase_done()) {
+            return;
+        }
+        flash_erase_state = 2;
+    }
+    */
+
+    // This code erases the flash directly, waiting for it to finish
+    if (!(flash_flags & FLASH_FLAG_ERASED)) {
+        flash_erase(flash_cache_sector_start, (const uint32_t*)CACHE_MEM_START_ADDR, flash_cache_sector_size / 4);
+        flash_flags |= FLASH_FLAG_ERASED;
+        return;
+    }
+
+    // If not a forced write, wait at least 5 seconds after last write to flush
+    // On file close and flash unmount we get a forced write, so we can afford to wait a while
+    if ((flash_flags & FLASH_FLAG_FORCE_WRITE) || sys_tick_has_passed(flash_tick_counter_last_write, 5000)) {
+        // sync the cache RAM buffer by writing it to the flash page
+        flash_write(flash_cache_sector_start, (const uint32_t*)CACHE_MEM_START_ADDR, flash_cache_sector_size / 4);
+        // clear the flash flags now that we have a clean cache
+        flash_flags = 0;
+        // indicate a clean cache with LED off
+        led_state(PYB_LED_R1, 0);
+    }
 }
 
 void storage_flush(void) {
@@ -167,7 +213,6 @@ bool storage_write_block(const uint8_t *src, uint32_t block) {
         uint32_t flash_addr = FLASH_MEM_START_ADDR + (block - FLASH_PART1_START_BLOCK) * FLASH_BLOCK_SIZE;
         uint8_t *dest = flash_cache_get_addr_for_write(flash_addr);
         memcpy(dest, src, FLASH_BLOCK_SIZE);
-        flash_tick_counter_last_write = HAL_GetTick();
         return true;
 
     } else {
diff --git a/stmhal/storage.h b/stmhal/storage.h
index 4d153d2f695b33af3e982b37b356a6448fabb662..ae6c832281eaed8b895e71be891f8c759d7f8805 100644
--- a/stmhal/storage.h
+++ b/stmhal/storage.h
@@ -3,7 +3,7 @@
 void storage_init(void);
 uint32_t storage_get_block_size(void);
 uint32_t storage_get_block_count(void);
-bool storage_needs_flush(void);
+void storage_irq_handler(void);
 void storage_flush(void);
 bool storage_read_block(uint8_t *dest, uint32_t block);
 bool storage_write_block(const uint8_t *src, uint32_t block);
diff --git a/stmhal/timer.c b/stmhal/timer.c
index 5ea605039d3f6d6ecf508ad3e4a3f8388d5f45de..26524fbdd903ddbe9d655700fa4344ae4827717b 100644
--- a/stmhal/timer.c
+++ b/stmhal/timer.c
@@ -19,6 +19,7 @@
 // the interrupts to be dispatched, so they are all collected here.
 //
 // TIM3:
+//  - flash storage controller, to flush the cache
 //  - USB CDC interface, interval, to check for new data
 //  - LED 4, PWM to set the LED intensity
 //
@@ -29,14 +30,17 @@ TIM_HandleTypeDef TIM3_Handle;
 TIM_HandleTypeDef TIM5_Handle;
 TIM_HandleTypeDef TIM6_Handle;
 
+// Used to divide down TIM3 and periodically call the flash storage IRQ
+static uint32_t tim3_counter = 0;
+
 // TIM3 is set-up for the USB CDC interface
 void timer_tim3_init(void) {
     // set up the timer for USBD CDC
     __TIM3_CLK_ENABLE();
 
     TIM3_Handle.Instance = TIM3;
-    TIM3_Handle.Init.Period = (USBD_CDC_POLLING_INTERVAL*1000) - 1;
-    TIM3_Handle.Init.Prescaler = 84-1;
+    TIM3_Handle.Init.Period = (USBD_CDC_POLLING_INTERVAL*1000) - 1; // TIM3 fires every USBD_CDC_POLLING_INTERVAL ms
+    TIM3_Handle.Init.Prescaler = 84-1; // for System clock at 168MHz, TIM3 runs at 1MHz
     TIM3_Handle.Init.ClockDivision = 0;
     TIM3_Handle.Init.CounterMode = TIM_COUNTERMODE_UP;
     HAL_TIM_Base_Init(&TIM3_Handle);
@@ -105,6 +109,13 @@ void timer_tim6_init(uint freq) {
 void HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim) {
     if (htim == &TIM3_Handle) {
         USBD_CDC_HAL_TIM_PeriodElapsedCallback();
+
+        // Periodically raise a flash IRQ for the flash storage controller
+        if (tim3_counter++ >= 500 / USBD_CDC_POLLING_INTERVAL) {
+            tim3_counter = 0;
+            NVIC->STIR = FLASH_IRQn;
+        }
+
     } else if (htim == &TIM5_Handle) {
         servo_timer_irq_callback();
     }
diff --git a/stmhal/usbd_msc_storage.c b/stmhal/usbd_msc_storage.c
index 0225a2a23c4146cde581bb563ab6dc419194eda5..3599e289de6e28f929d2787712bc779d451ace5e 100644
--- a/stmhal/usbd_msc_storage.c
+++ b/stmhal/usbd_msc_storage.c
@@ -109,6 +109,12 @@ int8_t FLASH_STORAGE_StopUnit(uint8_t lun) {
     return 0;
 }
 
+int8_t FLASH_STORAGE_PreventAllowMediumRemoval(uint8_t lun, uint8_t param) {
+    // sync the flash so that the cache is cleared and the device can be unplugged/turned off
+    disk_ioctl(0, CTRL_SYNC, NULL);
+    return 0;
+}
+
 /**
   * @brief  Read data from the medium
   * @param  lun : logical unit number
@@ -146,7 +152,6 @@ int8_t FLASH_STORAGE_Write (uint8_t lun, uint8_t *buf, uint32_t blk_addr, uint16
         }
     }
     */
-    storage_flush(); // XXX hack for now so that the cache is always flushed
     return 0;
 }
 
@@ -165,6 +170,7 @@ const USBD_StorageTypeDef USBD_FLASH_STORAGE_fops = {
     FLASH_STORAGE_IsReady,
     FLASH_STORAGE_IsWriteProtected,
     FLASH_STORAGE_StopUnit,
+    FLASH_STORAGE_PreventAllowMediumRemoval,
     FLASH_STORAGE_Read,
     FLASH_STORAGE_Write,
     FLASH_STORAGE_GetMaxLun,
@@ -295,6 +301,10 @@ int8_t SDCARD_STORAGE_StopUnit(uint8_t lun) {
     return 0;
 }
 
+int8_t SDCARD_STORAGE_PreventAllowMediumRemoval(uint8_t lun, uint8_t param) {
+    return 0;
+}
+
 /**
   * @brief  Read data from the medium
   * @param  lun : logical unit number
@@ -340,6 +350,7 @@ const USBD_StorageTypeDef USBD_SDCARD_STORAGE_fops = {
     SDCARD_STORAGE_IsReady,
     SDCARD_STORAGE_IsWriteProtected,
     SDCARD_STORAGE_StopUnit,
+    SDCARD_STORAGE_PreventAllowMediumRemoval,
     SDCARD_STORAGE_Read,
     SDCARD_STORAGE_Write,
     SDCARD_STORAGE_GetMaxLun,
diff --git a/stmhal/usbdev/class/cdc_msc_hid/inc/usbd_cdc_msc_hid.h b/stmhal/usbdev/class/cdc_msc_hid/inc/usbd_cdc_msc_hid.h
index 934399493119d9af8421018c05eaf8bf964d6885..e1ae578e32d73543550e57316b6e72ea9234a6b2 100644
--- a/stmhal/usbdev/class/cdc_msc_hid/inc/usbd_cdc_msc_hid.h
+++ b/stmhal/usbdev/class/cdc_msc_hid/inc/usbd_cdc_msc_hid.h
@@ -54,6 +54,7 @@ typedef struct _USBD_STORAGE {
   int8_t (* IsReady) (uint8_t lun);
   int8_t (* IsWriteProtected) (uint8_t lun);
   int8_t (* StopUnit)(uint8_t lun);
+  int8_t (* PreventAllowMediumRemoval)(uint8_t lun, uint8_t param0);
   int8_t (* Read) (uint8_t lun, uint8_t *buf, uint32_t blk_addr, uint16_t blk_len);
   int8_t (* Write)(uint8_t lun, uint8_t *buf, uint32_t blk_addr, uint16_t blk_len);
   int8_t (* GetMaxLun)(void);
diff --git a/stmhal/usbdev/class/cdc_msc_hid/src/usbd_msc_scsi.c b/stmhal/usbdev/class/cdc_msc_hid/src/usbd_msc_scsi.c
index b00d1ae2ce38e7f315ccd4e8e1066101e10e4a71..60258d64d9fd12c928b463715c9f9ccdc842bfe4 100644
--- a/stmhal/usbdev/class/cdc_msc_hid/src/usbd_msc_scsi.c
+++ b/stmhal/usbdev/class/cdc_msc_hid/src/usbd_msc_scsi.c
@@ -472,6 +472,7 @@ static int8_t SCSI_AllowMediumRemoval(USBD_HandleTypeDef  *pdev, uint8_t lun, ui
 {
   USBD_MSC_BOT_HandleTypeDef  *hmsc = pdev->pClassData;   
   hmsc->bot_data_length = 0;
+  ((USBD_StorageTypeDef *)pdev->pUserData)->PreventAllowMediumRemoval(lun, params[0]);
   return 0;
 }